OpenSound commited on
Commit
b9d6819
·
verified ·
1 Parent(s): 2c654bd

Upload 84 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. api.py +117 -0
  2. src/.idea/.gitignore +8 -0
  3. src/.idea/inspectionProfiles/Project_Default.xml +34 -0
  4. src/.idea/inspectionProfiles/profiles_settings.xml +6 -0
  5. src/.idea/misc.xml +7 -0
  6. src/.idea/modules.xml +8 -0
  7. src/.idea/src.iml +12 -0
  8. src/.idea/workspace.xml +128 -0
  9. src/inference.py +169 -0
  10. src/models/blocks.py +325 -0
  11. src/models/conditioners.py +180 -0
  12. src/models/udit.py +356 -0
  13. src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py +0 -0
  14. src/models/utils/.ipynb_checkpoints/attention-checkpoint.py +290 -0
  15. src/models/utils/.ipynb_checkpoints/modules-checkpoint.py +374 -0
  16. src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py +91 -0
  17. src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py +146 -0
  18. src/models/utils/.ipynb_checkpoints/timm-checkpoint.py +114 -0
  19. src/models/utils/__init__.py +0 -0
  20. src/models/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  21. src/models/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  22. src/models/utils/__pycache__/attention.cpython-310.pyc +0 -0
  23. src/models/utils/__pycache__/attention.cpython-311.pyc +0 -0
  24. src/models/utils/__pycache__/modules.cpython-310.pyc +0 -0
  25. src/models/utils/__pycache__/modules.cpython-311.pyc +0 -0
  26. src/models/utils/__pycache__/rotary.cpython-310.pyc +0 -0
  27. src/models/utils/__pycache__/rotary.cpython-311.pyc +0 -0
  28. src/models/utils/__pycache__/span_mask.cpython-310.pyc +0 -0
  29. src/models/utils/__pycache__/span_mask.cpython-311.pyc +0 -0
  30. src/models/utils/__pycache__/timm.cpython-310.pyc +0 -0
  31. src/models/utils/__pycache__/timm.cpython-311.pyc +0 -0
  32. src/models/utils/attention.py +290 -0
  33. src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py +99 -0
  34. src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py +74 -0
  35. src/models/utils/bk/__pycache__/rotary.cpython-311.pyc +0 -0
  36. src/models/utils/bk/attention.py +99 -0
  37. src/models/utils/bk/llama_rotary.py +74 -0
  38. src/models/utils/modules.py +374 -0
  39. src/models/utils/rotary.py +91 -0
  40. src/models/utils/span_mask.py +146 -0
  41. src/models/utils/timm.py +114 -0
  42. src/modules/autoencoder_wrapper.py +83 -0
  43. src/modules/clap_wrapper.py +0 -0
  44. src/modules/dac/__init__.py +16 -0
  45. src/modules/dac/__main__.py +36 -0
  46. src/modules/dac/compare/__init__.py +0 -0
  47. src/modules/dac/compare/encodec.py +54 -0
  48. src/modules/dac/model/__init__.py +4 -0
  49. src/modules/dac/model/base.py +294 -0
  50. src/modules/dac/model/dac.py +364 -0
api.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import gradio as gr
6
+ import soundfile as sf
7
+ from transformers import T5Tokenizer, T5EncoderModel
8
+ from diffusers import DDIMScheduler
9
+ from src.models.conditioners import MaskDiT
10
+ from src.modules.autoencoder_wrapper import Autoencoder
11
+ from src.inference import inference
12
+ from src.utils import load_yaml_with_includes
13
+
14
+
15
+ # Load model and configs
16
+ def load_models(config_name, ckpt_path, vae_path, device):
17
+ params = load_yaml_with_includes(config_name)
18
+
19
+ # Load codec model
20
+ autoencoder = Autoencoder(ckpt_path=vae_path,
21
+ model_type=params['autoencoder']['name'],
22
+ quantization_first=params['autoencoder']['q_first']).to(device)
23
+ autoencoder.eval()
24
+
25
+ # Load text encoder
26
+ tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
27
+ text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
28
+ text_encoder.eval()
29
+
30
+ # Load main U-Net model
31
+ unet = MaskDiT(**params['model']).to(device)
32
+ unet.load_state_dict(torch.load(ckpt_path)['model'])
33
+ unet.eval()
34
+
35
+ # Load noise scheduler
36
+ noise_scheduler = DDIMScheduler(**params['diff'])
37
+
38
+ return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params
39
+
40
+ MAX_SEED = np.iinfo(np.int32).max
41
+
42
+ # Model and config paths
43
+ config_name = 'ckpts/ezaudio-xl.yml'
44
+ ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
45
+ vae_path = 'ckpts/vae/1m.pt'
46
+ save_path = 'output/'
47
+ os.makedirs(save_path, exist_ok=True)
48
+
49
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
50
+
51
+ autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
52
+ device)
53
+
54
+ latents = torch.randn((1, 128, 128), device=device)
55
+ noise = torch.randn_like(latents)
56
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
57
+ _ = noise_scheduler.add_noise(latents, noise, timesteps)
58
+
59
+
60
+ # Inference function
61
+ def generate_audio(text, length,
62
+ guidance_scale, guidance_rescale, ddim_steps, eta,
63
+ random_seed, randomize_seed):
64
+ neg_text = None
65
+ length = length * params['autoencoder']['latent_sr']
66
+
67
+ if randomize_seed:
68
+ random_seed = random.randint(0, MAX_SEED)
69
+
70
+ pred = inference(autoencoder, unet, None, None,
71
+ tokenizer, text_encoder,
72
+ params, noise_scheduler,
73
+ text, neg_text,
74
+ length,
75
+ guidance_scale, guidance_rescale,
76
+ ddim_steps, eta, random_seed,
77
+ device)
78
+
79
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
80
+ # output_file = f"{save_path}/{text}.wav"
81
+ # sf.write(output_file, pred, samplerate=params['autoencoder']['sr'])
82
+
83
+ return params['autoencoder']['sr'], pred
84
+
85
+
86
+ # Gradio Interface
87
+ def gradio_interface():
88
+ # Input components
89
+ text_input = gr.Textbox(label="Text Prompt", value="the sound of dog barking")
90
+ length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
91
+
92
+ # Advanced settings
93
+ guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale")
94
+ guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
95
+ ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps")
96
+ eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta")
97
+ random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0,)
98
+
99
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
100
+
101
+ # Output component
102
+ output_audio = gr.Audio(label="Converted Audio", type="numpy")
103
+
104
+ # Interface
105
+ gr.Interface(
106
+ fn=generate_audio,
107
+ inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input,
108
+ random_seed_input, randomize_seed],
109
+ outputs=output_audio,
110
+ title="EzAudio Text-to-Audio Generator",
111
+ description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.",
112
+ allow_flagging="never"
113
+ ).launch()
114
+
115
+
116
+ if __name__ == "__main__":
117
+ gradio_interface()
src/.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
src/.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="21">
8
+ <item index="0" class="java.lang.String" itemvalue="numba" />
9
+ <item index="1" class="java.lang.String" itemvalue="scipy" />
10
+ <item index="2" class="java.lang.String" itemvalue="decorator" />
11
+ <item index="3" class="java.lang.String" itemvalue="six" />
12
+ <item index="4" class="java.lang.String" itemvalue="joblib" />
13
+ <item index="5" class="java.lang.String" itemvalue="threadpoolctl" />
14
+ <item index="6" class="java.lang.String" itemvalue="scikit-learn" />
15
+ <item index="7" class="java.lang.String" itemvalue="python-dateutil" />
16
+ <item index="8" class="java.lang.String" itemvalue="cffi" />
17
+ <item index="9" class="java.lang.String" itemvalue="SoundFile" />
18
+ <item index="10" class="java.lang.String" itemvalue="audioread" />
19
+ <item index="11" class="java.lang.String" itemvalue="kiwisolver" />
20
+ <item index="12" class="java.lang.String" itemvalue="cycler" />
21
+ <item index="13" class="java.lang.String" itemvalue="llvmlite" />
22
+ <item index="14" class="java.lang.String" itemvalue="mido" />
23
+ <item index="15" class="java.lang.String" itemvalue="matplotlib" />
24
+ <item index="16" class="java.lang.String" itemvalue="resampy" />
25
+ <item index="17" class="java.lang.String" itemvalue="librosa" />
26
+ <item index="18" class="java.lang.String" itemvalue="pyparsing" />
27
+ <item index="19" class="java.lang.String" itemvalue="pretty-midi" />
28
+ <item index="20" class="java.lang.String" itemvalue="Pillow" />
29
+ </list>
30
+ </value>
31
+ </option>
32
+ </inspection_tool>
33
+ </profile>
34
+ </component>
src/.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
src/.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.10" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
7
+ </project>
src/.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/src.iml" filepath="$PROJECT_DIR$/.idea/src.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
src/.idea/src.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
src/.idea/workspace.xml ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="cb82860d-7ce6-451e-932b-96d3a6e7b20d" name="Changes" comment="" />
8
+ <option name="SHOW_DIALOG" value="false" />
9
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
+ <option name="LAST_RESOLUTION" value="IGNORE" />
12
+ </component>
13
+ <component name="ProjectColorInfo"><![CDATA[{
14
+ "associatedIndex": 4
15
+ }]]></component>
16
+ <component name="ProjectId" id="2m8UaG5ZprDRDpwT0ASOoKWJVNg" />
17
+ <component name="ProjectViewState">
18
+ <option name="hideEmptyMiddlePackages" value="true" />
19
+ <option name="showLibraryContents" value="true" />
20
+ </component>
21
+ <component name="PropertiesComponent"><![CDATA[{
22
+ "keyToString": {
23
+ "Python.api.executor": "Run",
24
+ "Python.clean.executor": "Run",
25
+ "Python.gradio.executor": "Run",
26
+ "RunOnceActivity.ShowReadmeOnStart": "true",
27
+ "node.js.detected.package.eslint": "true",
28
+ "node.js.detected.package.tslint": "true",
29
+ "node.js.selected.package.eslint": "(autodetect)",
30
+ "node.js.selected.package.tslint": "(autodetect)",
31
+ "nodejs_package_manager_path": "npm",
32
+ "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
33
+ "vue.rearranger.settings.migration": "true"
34
+ }
35
+ }]]></component>
36
+ <component name="RunManager" selected="Python.api">
37
+ <configuration name="api" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
38
+ <module name="src" />
39
+ <option name="ENV_FILES" value="" />
40
+ <option name="INTERPRETER_OPTIONS" value="" />
41
+ <option name="PARENT_ENVS" value="true" />
42
+ <option name="SDK_HOME" value="" />
43
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/.." />
44
+ <option name="IS_MODULE_SDK" value="false" />
45
+ <option name="ADD_CONTENT_ROOTS" value="true" />
46
+ <option name="ADD_SOURCE_ROOTS" value="true" />
47
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
48
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/../api.py" />
49
+ <option name="PARAMETERS" value="" />
50
+ <option name="SHOW_COMMAND_LINE" value="false" />
51
+ <option name="EMULATE_TERMINAL" value="false" />
52
+ <option name="MODULE_MODE" value="false" />
53
+ <option name="REDIRECT_INPUT" value="false" />
54
+ <option name="INPUT_FILE" value="" />
55
+ <method v="2" />
56
+ </configuration>
57
+ <configuration name="clean" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
58
+ <module name="src" />
59
+ <option name="ENV_FILES" value="" />
60
+ <option name="INTERPRETER_OPTIONS" value="" />
61
+ <option name="PARENT_ENVS" value="true" />
62
+ <option name="SDK_HOME" value="" />
63
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/../ckpts/vae" />
64
+ <option name="IS_MODULE_SDK" value="false" />
65
+ <option name="ADD_CONTENT_ROOTS" value="true" />
66
+ <option name="ADD_SOURCE_ROOTS" value="true" />
67
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
68
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/../ckpts/vae/clean.py" />
69
+ <option name="PARAMETERS" value="" />
70
+ <option name="SHOW_COMMAND_LINE" value="false" />
71
+ <option name="EMULATE_TERMINAL" value="false" />
72
+ <option name="MODULE_MODE" value="false" />
73
+ <option name="REDIRECT_INPUT" value="false" />
74
+ <option name="INPUT_FILE" value="" />
75
+ <method v="2" />
76
+ </configuration>
77
+ <configuration name="gradio" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
78
+ <module name="src" />
79
+ <option name="ENV_FILES" value="" />
80
+ <option name="INTERPRETER_OPTIONS" value="" />
81
+ <option name="PARENT_ENVS" value="true" />
82
+ <option name="SDK_HOME" value="" />
83
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/.." />
84
+ <option name="IS_MODULE_SDK" value="false" />
85
+ <option name="ADD_CONTENT_ROOTS" value="true" />
86
+ <option name="ADD_SOURCE_ROOTS" value="true" />
87
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
88
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/../gradio.py" />
89
+ <option name="PARAMETERS" value="" />
90
+ <option name="SHOW_COMMAND_LINE" value="false" />
91
+ <option name="EMULATE_TERMINAL" value="false" />
92
+ <option name="MODULE_MODE" value="false" />
93
+ <option name="REDIRECT_INPUT" value="false" />
94
+ <option name="INPUT_FILE" value="" />
95
+ <method v="2" />
96
+ </configuration>
97
+ <recent_temporary>
98
+ <list>
99
+ <item itemvalue="Python.api" />
100
+ <item itemvalue="Python.gradio" />
101
+ <item itemvalue="Python.clean" />
102
+ </list>
103
+ </recent_temporary>
104
+ </component>
105
+ <component name="SharedIndexes">
106
+ <attachedChunks>
107
+ <set>
108
+ <option value="bundled-js-predefined-1d06a55b98c1-74d2a5396914-JavaScript-PY-241.14494.241" />
109
+ <option value="bundled-python-sdk-0509580d9d50-28c9f5db9ffe-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-241.14494.241" />
110
+ </set>
111
+ </attachedChunks>
112
+ </component>
113
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
114
+ <component name="TaskManager">
115
+ <task active="true" id="Default" summary="Default task">
116
+ <changelist id="cb82860d-7ce6-451e-932b-96d3a6e7b20d" name="Changes" comment="" />
117
+ <created>1726457759523</created>
118
+ <option name="number" value="Default" />
119
+ <option name="presentableId" value="Default" />
120
+ <updated>1726457759523</updated>
121
+ <workItem from="1726457760668" duration="3668000" />
122
+ </task>
123
+ <servers />
124
+ </component>
125
+ <component name="TypeScriptGeneratedFilesManager">
126
+ <option name="version" value="3" />
127
+ </component>
128
+ </project>
src/inference.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import pandas as pd
4
+ import torch
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile as sf
8
+ from tqdm import tqdm
9
+ from utils import scale_shift_re
10
+
11
+
12
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
13
+ """
14
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
15
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
16
+ """
17
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
18
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
19
+ # rescale the results from guidance (fixes overexposure)
20
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
21
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
22
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
23
+ return noise_cfg
24
+
25
+
26
+ @torch.no_grad()
27
+ def inference(autoencoder, unet, gt, gt_mask,
28
+ tokenizer, text_encoder,
29
+ params, noise_scheduler,
30
+ text_raw, neg_text=None,
31
+ audio_frames=500,
32
+ guidance_scale=3, guidance_rescale=0.0,
33
+ ddim_steps=50, eta=1, random_seed=2024,
34
+ device='cuda',
35
+ ):
36
+ if neg_text is None:
37
+ neg_text = [""]
38
+ if tokenizer is not None:
39
+ text_batch = tokenizer(text_raw,
40
+ max_length=params['text_encoder']['max_length'],
41
+ padding="max_length", truncation=True, return_tensors="pt")
42
+ text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
43
+ text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
44
+
45
+ uncond_text_batch = tokenizer(neg_text,
46
+ max_length=params['text_encoder']['max_length'],
47
+ padding="max_length", truncation=True, return_tensors="pt")
48
+ uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
49
+ uncond_text = text_encoder(input_ids=uncond_text,
50
+ attention_mask=uncond_text_mask).last_hidden_state
51
+ else:
52
+ text, text_mask = None, None
53
+ guidance_scale = None
54
+
55
+ codec_dim = params['model']['out_chans']
56
+ unet.eval()
57
+
58
+ if random_seed is not None:
59
+ generator = torch.Generator(device=device).manual_seed(random_seed)
60
+ else:
61
+ generator = torch.Generator(device=device)
62
+ generator.seed()
63
+
64
+ noise_scheduler.set_timesteps(ddim_steps)
65
+
66
+ # init noise
67
+ noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
68
+ latents = noise
69
+
70
+ for t in noise_scheduler.timesteps:
71
+ latents = noise_scheduler.scale_model_input(latents, t)
72
+
73
+ if guidance_scale:
74
+
75
+ latents_combined = torch.cat([latents, latents], dim=0)
76
+ text_combined = torch.cat([text, uncond_text], dim=0)
77
+ text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
78
+
79
+ if gt is not None:
80
+ gt_combined = torch.cat([gt, gt], dim=0)
81
+ gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
82
+ else:
83
+ gt_combined = None
84
+ gt_mask_combined = None
85
+
86
+ output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
87
+ cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined)
88
+ output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
89
+
90
+ output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
91
+ if guidance_rescale > 0.0:
92
+ output_pred = rescale_noise_cfg(output_pred, output_text,
93
+ guidance_rescale=guidance_rescale)
94
+ else:
95
+ output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask,
96
+ cls_token=None, gt=gt, mae_mask_infer=gt_mask)
97
+
98
+ latents = noise_scheduler.step(model_output=output_pred, timestep=t,
99
+ sample=latents,
100
+ eta=eta, generator=generator).prev_sample
101
+
102
+ pred = scale_shift_re(latents, params['autoencoder']['scale'],
103
+ params['autoencoder']['shift'])
104
+ if gt is not None:
105
+ pred[~gt_mask] = gt[~gt_mask]
106
+ pred_wav = autoencoder(embedding=pred)
107
+ return pred_wav
108
+
109
+
110
+ @torch.no_grad()
111
+ def eval_udit(autoencoder, unet,
112
+ tokenizer, text_encoder,
113
+ params, noise_scheduler,
114
+ val_df, subset,
115
+ audio_frames, mae=False,
116
+ guidance_scale=3, guidance_rescale=0.0,
117
+ ddim_steps=50, eta=1, random_seed=2023,
118
+ device='cuda',
119
+ epoch=0, save_path='logs/eval/', val_num=5):
120
+ val_df = pd.read_csv(val_df)
121
+ val_df = val_df[val_df['split'] == subset]
122
+ if mae:
123
+ val_df = val_df[val_df['audio_length'] != 0]
124
+
125
+ save_path = save_path + str(epoch) + '/'
126
+ os.makedirs(save_path, exist_ok=True)
127
+
128
+ for i in tqdm(range(len(val_df))):
129
+ row = val_df.iloc[i]
130
+ text = [row['caption']]
131
+ if mae:
132
+ audio_path = params['data']['val_dir'] + str(row['audio_path'])
133
+ gt, sr = librosa.load(audio_path, sr=params['data']['sr'])
134
+ gt = gt / (np.max(np.abs(gt)) + 1e-9)
135
+ sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr'])
136
+ num_samples = 10 * sr
137
+ if len(gt) < num_samples:
138
+ padding = num_samples - len(gt)
139
+ gt = np.pad(gt, (0, padding), 'constant')
140
+ else:
141
+ gt = gt[:num_samples]
142
+ gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
143
+ gt = autoencoder(audio=gt)
144
+ B, D, L = gt.shape
145
+ mask_len = int(L * 0.2)
146
+ gt_mask = torch.zeros(B, D, L).to(device)
147
+ for _ in range(2):
148
+ start = random.randint(0, L - mask_len)
149
+ gt_mask[:, :, start:start + mask_len] = 1
150
+ gt_mask = gt_mask.bool()
151
+ else:
152
+ gt = None
153
+ gt_mask = None
154
+
155
+ pred = inference(autoencoder, unet, gt, gt_mask,
156
+ tokenizer, text_encoder,
157
+ params, noise_scheduler,
158
+ text, neg_text=None,
159
+ audio_frames=audio_frames,
160
+ guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
161
+ ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
162
+ device=device)
163
+
164
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
165
+
166
+ sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr'])
167
+
168
+ if i + 1 >= val_num:
169
+ break
src/models/blocks.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ from .utils.attention import Attention, JointAttention
5
+ from .utils.modules import unpatchify, FeedForward
6
+ from .utils.modules import film_modulate
7
+
8
+
9
+ class AdaLN(nn.Module):
10
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
11
+ super().__init__()
12
+ self.ada_mode = ada_mode
13
+ self.scale_shift_table = None
14
+ if ada_mode == 'ada':
15
+ # move nn.silu outside
16
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
17
+ elif ada_mode == 'ada_single':
18
+ # adaln used in pixel-art alpha
19
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
20
+ elif ada_mode in ['ada_lora', 'ada_lora_bias']:
21
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
22
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
23
+ self.scaling = alpha / r
24
+ if ada_mode == 'ada_lora_bias':
25
+ # take bias out for consistency
26
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
27
+ else:
28
+ raise NotImplementedError
29
+
30
+ def forward(self, time_token=None, time_ada=None):
31
+ if self.ada_mode == 'ada':
32
+ assert time_ada is None
33
+ B = time_token.shape[0]
34
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
35
+ elif self.ada_mode == 'ada_single':
36
+ B = time_ada.shape[0]
37
+ time_ada = time_ada.reshape(B, 6, -1)
38
+ time_ada = self.scale_shift_table[None] + time_ada
39
+ elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
40
+ B = time_ada.shape[0]
41
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
42
+ time_ada = time_ada + time_ada_lora
43
+ time_ada = time_ada.reshape(B, 6, -1)
44
+ if self.scale_shift_table is not None:
45
+ time_ada = self.scale_shift_table[None] + time_ada
46
+ else:
47
+ raise NotImplementedError
48
+ return time_ada
49
+
50
+
51
+ class DiTBlock(nn.Module):
52
+ """
53
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
54
+ """
55
+
56
+ def __init__(self, dim, context_dim=None,
57
+ num_heads=8, mlp_ratio=4.,
58
+ qkv_bias=False, qk_scale=None, qk_norm=None,
59
+ act_layer='gelu', norm_layer=nn.LayerNorm,
60
+ time_fusion='none',
61
+ ada_lora_rank=None, ada_lora_alpha=None,
62
+ skip=False, skip_norm=False,
63
+ rope_mode='none',
64
+ context_norm=False,
65
+ use_checkpoint=False):
66
+
67
+ super().__init__()
68
+ self.norm1 = norm_layer(dim)
69
+ self.attn = Attention(dim=dim,
70
+ num_heads=num_heads,
71
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
72
+ qk_norm=qk_norm,
73
+ rope_mode=rope_mode)
74
+
75
+ if context_dim is not None:
76
+ self.use_context = True
77
+ self.cross_attn = Attention(dim=dim,
78
+ num_heads=num_heads,
79
+ context_dim=context_dim,
80
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
81
+ qk_norm=qk_norm,
82
+ rope_mode='none')
83
+ self.norm2 = norm_layer(dim)
84
+ if context_norm:
85
+ self.norm_context = norm_layer(context_dim)
86
+ else:
87
+ self.norm_context = nn.Identity()
88
+ else:
89
+ self.use_context = False
90
+
91
+ self.norm3 = norm_layer(dim)
92
+ self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
93
+ activation_fn=act_layer, dropout=0)
94
+
95
+ self.use_adanorm = True if time_fusion != 'token' else False
96
+ if self.use_adanorm:
97
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
98
+ r=ada_lora_rank, alpha=ada_lora_alpha)
99
+ if skip:
100
+ self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
101
+ self.skip_linear = nn.Linear(2 * dim, dim)
102
+ else:
103
+ self.skip_linear = None
104
+
105
+ self.use_checkpoint = use_checkpoint
106
+
107
+ def forward(self, x, time_token=None, time_ada=None,
108
+ skip=None, context=None,
109
+ x_mask=None, context_mask=None, extras=None):
110
+ if self.use_checkpoint:
111
+ return checkpoint(self._forward, x,
112
+ time_token, time_ada, skip, context,
113
+ x_mask, context_mask, extras,
114
+ use_reentrant=False)
115
+ else:
116
+ return self._forward(x,
117
+ time_token, time_ada, skip, context,
118
+ x_mask, context_mask, extras)
119
+
120
+ def _forward(self, x, time_token=None, time_ada=None,
121
+ skip=None, context=None,
122
+ x_mask=None, context_mask=None, extras=None):
123
+ B, T, C = x.shape
124
+ if self.skip_linear is not None:
125
+ assert skip is not None
126
+ cat = torch.cat([x, skip], dim=-1)
127
+ cat = self.skip_norm(cat)
128
+ x = self.skip_linear(cat)
129
+
130
+ if self.use_adanorm:
131
+ time_ada = self.adaln(time_token, time_ada)
132
+ (shift_msa, scale_msa, gate_msa,
133
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
134
+
135
+ # self attention
136
+ if self.use_adanorm:
137
+ x_norm = film_modulate(self.norm1(x), shift=shift_msa,
138
+ scale=scale_msa)
139
+ x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
140
+ context_mask=x_mask,
141
+ extras=extras)
142
+ else:
143
+ x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
144
+ extras=extras)
145
+
146
+ # cross attention
147
+ if self.use_context:
148
+ assert context is not None
149
+ x = x + self.cross_attn(x=self.norm2(x),
150
+ context=self.norm_context(context),
151
+ context_mask=context_mask, extras=extras)
152
+
153
+ # mlp
154
+ if self.use_adanorm:
155
+ x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
156
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
157
+ else:
158
+ x = x + self.mlp(self.norm3(x))
159
+
160
+ return x
161
+
162
+
163
+ class JointDiTBlock(nn.Module):
164
+ """
165
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
166
+ """
167
+
168
+ def __init__(self, dim, context_dim=None,
169
+ num_heads=8, mlp_ratio=4.,
170
+ qkv_bias=False, qk_scale=None, qk_norm=None,
171
+ act_layer='gelu', norm_layer=nn.LayerNorm,
172
+ time_fusion='none',
173
+ ada_lora_rank=None, ada_lora_alpha=None,
174
+ skip=(False, False),
175
+ rope_mode=False,
176
+ context_norm=False,
177
+ use_checkpoint=False,):
178
+
179
+ super().__init__()
180
+ # no cross attention
181
+ assert context_dim is None
182
+ self.attn_norm_x = norm_layer(dim)
183
+ self.attn_norm_c = norm_layer(dim)
184
+ self.attn = JointAttention(dim=dim,
185
+ num_heads=num_heads,
186
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
187
+ qk_norm=qk_norm,
188
+ rope_mode=rope_mode)
189
+ self.ffn_norm_x = norm_layer(dim)
190
+ self.ffn_norm_c = norm_layer(dim)
191
+ self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
192
+ activation_fn=act_layer, dropout=0)
193
+ self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
194
+ activation_fn=act_layer, dropout=0)
195
+
196
+ # Zero-out the shift table
197
+ self.use_adanorm = True if time_fusion != 'token' else False
198
+ if self.use_adanorm:
199
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
200
+ r=ada_lora_rank, alpha=ada_lora_alpha)
201
+
202
+ if skip is False:
203
+ skip_x, skip_c = False, False
204
+ else:
205
+ skip_x, skip_c = skip
206
+
207
+ self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
208
+ self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
209
+
210
+ self.use_checkpoint = use_checkpoint
211
+
212
+ def forward(self, x, time_token=None, time_ada=None,
213
+ skip=None, context=None,
214
+ x_mask=None, context_mask=None, extras=None):
215
+ if self.use_checkpoint:
216
+ return checkpoint(self._forward, x,
217
+ time_token, time_ada, skip,
218
+ context, x_mask, context_mask, extras,
219
+ use_reentrant=False)
220
+ else:
221
+ return self._forward(x,
222
+ time_token, time_ada, skip,
223
+ context, x_mask, context_mask, extras)
224
+
225
+ def _forward(self, x, time_token=None, time_ada=None,
226
+ skip=None, context=None,
227
+ x_mask=None, context_mask=None, extras=None):
228
+
229
+ assert context is None and context_mask is None
230
+
231
+ context, x = x[:, :extras, :], x[:, extras:, :]
232
+ context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
233
+
234
+ if skip is not None:
235
+ skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
236
+
237
+ B, T, C = x.shape
238
+ if self.skip_linear_x is not None:
239
+ x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
240
+
241
+ if self.skip_linear_c is not None:
242
+ context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
243
+
244
+ if self.use_adanorm:
245
+ time_ada = self.adaln(time_token, time_ada)
246
+ (shift_msa, scale_msa, gate_msa,
247
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
248
+
249
+ # self attention
250
+ x_norm = self.attn_norm_x(x)
251
+ c_norm = self.attn_norm_c(context)
252
+ if self.use_adanorm:
253
+ x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
254
+ x_out, c_out = self.attn(x_norm, context=c_norm,
255
+ x_mask=x_mask, context_mask=context_mask,
256
+ extras=extras)
257
+ if self.use_adanorm:
258
+ x = x + (1 - gate_msa) * x_out
259
+ else:
260
+ x = x + x_out
261
+ context = context + c_out
262
+
263
+ # mlp
264
+ if self.use_adanorm:
265
+ x_norm = film_modulate(self.ffn_norm_x(x),
266
+ shift=shift_mlp, scale=scale_mlp)
267
+ x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
268
+ else:
269
+ x = x + self.mlp_x(self.ffn_norm_x(x))
270
+
271
+ c_norm = self.ffn_norm_c(context)
272
+ context = context + self.mlp_c(c_norm)
273
+
274
+ return torch.cat((context, x), dim=1)
275
+
276
+
277
+ class FinalBlock(nn.Module):
278
+ def __init__(self, embed_dim, patch_size, in_chans,
279
+ img_size,
280
+ input_type='2d',
281
+ norm_layer=nn.LayerNorm,
282
+ use_conv=True,
283
+ use_adanorm=True):
284
+ super().__init__()
285
+ self.in_chans = in_chans
286
+ self.img_size = img_size
287
+ self.input_type = input_type
288
+
289
+ self.norm = norm_layer(embed_dim)
290
+ if use_adanorm:
291
+ self.use_adanorm = True
292
+ else:
293
+ self.use_adanorm = False
294
+
295
+ if input_type == '2d':
296
+ self.patch_dim = patch_size ** 2 * in_chans
297
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
298
+ if use_conv:
299
+ self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
300
+ 3, padding=1)
301
+ else:
302
+ self.final_layer = nn.Identity()
303
+
304
+ elif input_type == '1d':
305
+ self.patch_dim = patch_size * in_chans
306
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
307
+ if use_conv:
308
+ self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
309
+ 3, padding=1)
310
+ else:
311
+ self.final_layer = nn.Identity()
312
+
313
+ def forward(self, x, time_ada=None, extras=0):
314
+ B, T, C = x.shape
315
+ x = x[:, extras:, :]
316
+ # only handle generation target
317
+ if self.use_adanorm:
318
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
319
+ x = film_modulate(self.norm(x), shift, scale)
320
+ else:
321
+ x = self.norm(x)
322
+ x = self.linear(x)
323
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
324
+ x = self.final_layer(x)
325
+ return x
src/models/conditioners.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+ import math
6
+ from .udit import UDiT
7
+ from .utils.span_mask import compute_mask_indices
8
+
9
+
10
+ class EmbeddingCFG(nn.Module):
11
+ """
12
+ Handles label dropout for classifier-free guidance.
13
+ """
14
+ # todo: support 2D input
15
+
16
+ def __init__(self, in_channels):
17
+ super().__init__()
18
+ self.cfg_embedding = nn.Parameter(
19
+ torch.randn(in_channels) / in_channels ** 0.5)
20
+
21
+ def token_drop(self, condition, condition_mask, cfg_prob):
22
+ """
23
+ Drops labels to enable classifier-free guidance.
24
+ """
25
+ b, t, device = condition.shape[0], condition.shape[1], condition.device
26
+ drop_ids = torch.rand(b, device=device) < cfg_prob
27
+ uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
28
+ condition = torch.where(drop_ids[:, None, None], uncond, condition)
29
+ if condition_mask is not None:
30
+ condition_mask[drop_ids] = False
31
+ condition_mask[drop_ids, 0] = True
32
+
33
+ return condition, condition_mask
34
+
35
+ def forward(self, condition, condition_mask, cfg_prob=0.0):
36
+ if condition_mask is not None:
37
+ condition_mask = condition_mask.clone()
38
+ if cfg_prob > 0:
39
+ condition, condition_mask = self.token_drop(condition,
40
+ condition_mask,
41
+ cfg_prob)
42
+ return condition, condition_mask
43
+
44
+
45
+ class DiscreteCFG(nn.Module):
46
+ def __init__(self, replace_id=2):
47
+ super(DiscreteCFG, self).__init__()
48
+ self.replace_id = replace_id
49
+
50
+ def forward(self, context, context_mask, cfg_prob):
51
+ context = context.clone()
52
+ if context_mask is not None:
53
+ context_mask = context_mask.clone()
54
+ if cfg_prob > 0:
55
+ cfg_mask = torch.rand(len(context)) < cfg_prob
56
+ if torch.any(cfg_mask):
57
+ context[cfg_mask] = 0
58
+ context[cfg_mask, 0] = self.replace_id
59
+ if context_mask is not None:
60
+ context_mask[cfg_mask] = False
61
+ context_mask[cfg_mask, 0] = True
62
+ return context, context_mask
63
+
64
+
65
+ class CFGModel(nn.Module):
66
+ def __init__(self, context_dim, backbone):
67
+ super().__init__()
68
+ self.model = backbone
69
+ self.context_cfg = EmbeddingCFG(context_dim)
70
+
71
+ def forward(self, x, timesteps,
72
+ context, x_mask=None, context_mask=None,
73
+ cfg_prob=0.0):
74
+ context = self.context_cfg(context, cfg_prob)
75
+ x = self.model(x=x, timesteps=timesteps,
76
+ context=context,
77
+ x_mask=x_mask, context_mask=context_mask)
78
+ return x
79
+
80
+
81
+ class ConcatModel(nn.Module):
82
+ def __init__(self, backbone, in_dim, stride=[]):
83
+ super().__init__()
84
+ self.model = backbone
85
+
86
+ self.downsample_layers = nn.ModuleList()
87
+ for i, s in enumerate(stride):
88
+ downsample_layer = nn.Conv1d(
89
+ in_dim,
90
+ in_dim * 2,
91
+ kernel_size=2 * s,
92
+ stride=s,
93
+ padding=math.ceil(s / 2),
94
+ )
95
+ self.downsample_layers.append(downsample_layer)
96
+ in_dim = in_dim * 2
97
+
98
+ self.context_cfg = EmbeddingCFG(in_dim)
99
+
100
+ def forward(self, x, timesteps,
101
+ context, x_mask=None,
102
+ cfg=False, cfg_prob=0.0):
103
+
104
+ # todo: support 2D input
105
+ # x: B, C, L
106
+ # context: B, C, L
107
+
108
+ for downsample_layer in self.downsample_layers:
109
+ context = downsample_layer(context)
110
+
111
+ context = context.transpose(1, 2)
112
+ context = self.context_cfg(caption=context,
113
+ cfg=cfg, cfg_prob=cfg_prob)
114
+ context = context.transpose(1, 2)
115
+
116
+ assert context.shape[-1] == x.shape[-1]
117
+ x = torch.cat([context, x], dim=1)
118
+ x = self.model(x=x, timesteps=timesteps,
119
+ context=None, x_mask=x_mask, context_mask=None)
120
+ return x
121
+
122
+
123
+ class MaskDiT(nn.Module):
124
+ def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
125
+ super().__init__()
126
+ self.model = UDiT(**kwargs)
127
+ self.mae = mae
128
+ if self.mae:
129
+ out_channel = kwargs.pop('out_chans', None)
130
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
131
+ self.mae_prob = mae_prob
132
+ self.mask_ratio = mask_ratio
133
+ self.mask_span = mask_span
134
+
135
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
136
+ B, D, L = gt.shape
137
+ if mae_mask_infer is None:
138
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
139
+ mask_ratios = mask_ratios.cpu().numpy()
140
+ mask = compute_mask_indices(shape=[B, L],
141
+ padding_mask=None,
142
+ mask_prob=mask_ratios,
143
+ mask_length=self.mask_span,
144
+ mask_type="static",
145
+ mask_other=0.0,
146
+ min_masks=1,
147
+ no_overlap=False,
148
+ min_space=0,)
149
+ mask = mask.unsqueeze(1).expand_as(gt)
150
+ else:
151
+ mask = mae_mask_infer
152
+ mask = mask.expand_as(gt)
153
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
154
+ return gt, mask.type_as(gt)
155
+
156
+ def forward(self, x, timesteps, context,
157
+ x_mask=None, context_mask=None, cls_token=None,
158
+ gt=None, mae_mask_infer=None):
159
+ mae_mask = torch.ones_like(x)
160
+ if self.mae:
161
+ if gt is not None:
162
+ B, D, L = gt.shape
163
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
164
+ gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
165
+ # apply mae only to the selected batches
166
+ if mae_mask_infer is None:
167
+ # determine mae batch
168
+ mae_batch = torch.rand(B) < self.mae_prob
169
+ gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
170
+ mae_mask[~mae_batch] = 1.0
171
+ else:
172
+ B, D, L = x.shape
173
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
174
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
175
+
176
+ x = self.model(x=x, timesteps=timesteps, context=context,
177
+ x_mask=x_mask, context_mask=context_mask,
178
+ cls_token=cls_token)
179
+ # print(mae_mask[:, 0, :].sum(dim=-1))
180
+ return x, mae_mask
src/models/udit.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint
4
+ import math
5
+ from .utils.modules import PatchEmbed, TimestepEmbedder
6
+ from .utils.modules import PE_wrapper, RMSNorm
7
+ from .blocks import DiTBlock, JointDiTBlock, FinalBlock
8
+
9
+
10
+ class UDiT(nn.Module):
11
+ def __init__(self,
12
+ img_size=224, patch_size=16, in_chans=3,
13
+ input_type='2d', out_chans=None,
14
+ embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
15
+ qkv_bias=False, qk_scale=None, qk_norm=None,
16
+ act_layer='gelu', norm_layer='layernorm',
17
+ context_norm=False,
18
+ use_checkpoint=False,
19
+ # time fusion ada or token
20
+ time_fusion='token',
21
+ ada_lora_rank=None, ada_lora_alpha=None,
22
+ cls_dim=None,
23
+ # max length is only used for concat
24
+ context_dim=768, context_fusion='concat',
25
+ context_max_length=128, context_pe_method='sinu',
26
+ pe_method='abs', rope_mode='none',
27
+ use_conv=True,
28
+ skip=True, skip_norm=True):
29
+ super().__init__()
30
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
31
+
32
+ # input
33
+ self.in_chans = in_chans
34
+ self.input_type = input_type
35
+ if self.input_type == '2d':
36
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
37
+ elif self.input_type == '1d':
38
+ num_patches = img_size // patch_size
39
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
40
+ embed_dim=embed_dim, input_type=input_type)
41
+ out_chans = in_chans if out_chans is None else out_chans
42
+ self.out_chans = out_chans
43
+
44
+ # position embedding
45
+ self.rope = rope_mode
46
+ self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
47
+ length=num_patches)
48
+
49
+ print(f'x position embedding: {pe_method}')
50
+ print(f'rope mode: {self.rope}')
51
+
52
+ # time embed
53
+ self.time_embed = TimestepEmbedder(embed_dim)
54
+ self.time_fusion = time_fusion
55
+ self.use_adanorm = False
56
+
57
+ # cls embed
58
+ if cls_dim is not None:
59
+ self.cls_embed = nn.Sequential(
60
+ nn.Linear(cls_dim, embed_dim, bias=True),
61
+ nn.SiLU(),
62
+ nn.Linear(embed_dim, embed_dim, bias=True),)
63
+ else:
64
+ self.cls_embed = None
65
+
66
+ # time fusion
67
+ if time_fusion == 'token':
68
+ # put token at the beginning of sequence
69
+ self.extras = 2 if self.cls_embed else 1
70
+ self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
71
+ elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
72
+ self.use_adanorm = True
73
+ # aviod repetitive silu for each adaln block
74
+ self.time_act = nn.SiLU()
75
+ self.extras = 0
76
+ self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
77
+ if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
78
+ # shared adaln
79
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
80
+ else:
81
+ self.time_ada = None
82
+ else:
83
+ raise NotImplementedError
84
+ print(f'time fusion mode: {self.time_fusion}')
85
+
86
+ # context
87
+ # use a simple projection
88
+ self.use_context = False
89
+ self.context_cross = False
90
+ self.context_max_length = context_max_length
91
+ self.context_fusion = 'none'
92
+ if context_dim is not None:
93
+ self.use_context = True
94
+ self.context_embed = nn.Sequential(
95
+ nn.Linear(context_dim, embed_dim, bias=True),
96
+ nn.SiLU(),
97
+ nn.Linear(embed_dim, embed_dim, bias=True),)
98
+ self.context_fusion = context_fusion
99
+ if context_fusion == 'concat' or context_fusion == 'joint':
100
+ self.extras += context_max_length
101
+ self.context_pe = PE_wrapper(dim=embed_dim,
102
+ method=context_pe_method,
103
+ length=context_max_length)
104
+ # no cross attention layers
105
+ context_dim = None
106
+ elif context_fusion == 'cross':
107
+ self.context_pe = PE_wrapper(dim=embed_dim,
108
+ method=context_pe_method,
109
+ length=context_max_length)
110
+ self.context_cross = True
111
+ context_dim = embed_dim
112
+ else:
113
+ raise NotImplementedError
114
+ print(f'context fusion mode: {context_fusion}')
115
+ print(f'context position embedding: {context_pe_method}')
116
+
117
+ if self.context_fusion == 'joint':
118
+ Block = JointDiTBlock
119
+ self.use_skip = skip[0]
120
+ else:
121
+ Block = DiTBlock
122
+ self.use_skip = skip
123
+
124
+ # norm layers
125
+ if norm_layer == 'layernorm':
126
+ norm_layer = nn.LayerNorm
127
+ elif norm_layer == 'rmsnorm':
128
+ norm_layer = RMSNorm
129
+ else:
130
+ raise NotImplementedError
131
+
132
+ print(f'use long skip connection: {skip}')
133
+ self.in_blocks = nn.ModuleList([
134
+ Block(
135
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
136
+ mlp_ratio=mlp_ratio,
137
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
138
+ act_layer=act_layer, norm_layer=norm_layer,
139
+ time_fusion=time_fusion,
140
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
141
+ skip=False, skip_norm=False,
142
+ rope_mode=self.rope,
143
+ context_norm=context_norm,
144
+ use_checkpoint=use_checkpoint)
145
+ for _ in range(depth // 2)])
146
+
147
+ self.mid_block = Block(
148
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
149
+ mlp_ratio=mlp_ratio,
150
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
151
+ act_layer=act_layer, norm_layer=norm_layer,
152
+ time_fusion=time_fusion,
153
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
154
+ skip=False, skip_norm=False,
155
+ rope_mode=self.rope,
156
+ context_norm=context_norm,
157
+ use_checkpoint=use_checkpoint)
158
+
159
+ self.out_blocks = nn.ModuleList([
160
+ Block(
161
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
162
+ mlp_ratio=mlp_ratio,
163
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
164
+ act_layer=act_layer, norm_layer=norm_layer,
165
+ time_fusion=time_fusion,
166
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
167
+ skip=skip, skip_norm=skip_norm,
168
+ rope_mode=self.rope,
169
+ context_norm=context_norm,
170
+ use_checkpoint=use_checkpoint)
171
+ for _ in range(depth // 2)])
172
+
173
+ # FinalLayer block
174
+ self.use_conv = use_conv
175
+ self.final_block = FinalBlock(embed_dim=embed_dim,
176
+ patch_size=patch_size,
177
+ img_size=img_size,
178
+ in_chans=out_chans,
179
+ input_type=input_type,
180
+ norm_layer=norm_layer,
181
+ use_conv=use_conv,
182
+ use_adanorm=self.use_adanorm)
183
+ self.initialize_weights()
184
+
185
+ def _init_ada(self):
186
+ if self.time_fusion == 'ada':
187
+ nn.init.constant_(self.time_ada_final.weight, 0)
188
+ nn.init.constant_(self.time_ada_final.bias, 0)
189
+ for block in self.in_blocks:
190
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
191
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
192
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
193
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
194
+ for block in self.out_blocks:
195
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
196
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
197
+ elif self.time_fusion == 'ada_single':
198
+ nn.init.constant_(self.time_ada.weight, 0)
199
+ nn.init.constant_(self.time_ada.bias, 0)
200
+ nn.init.constant_(self.time_ada_final.weight, 0)
201
+ nn.init.constant_(self.time_ada_final.bias, 0)
202
+ elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
203
+ nn.init.constant_(self.time_ada.weight, 0)
204
+ nn.init.constant_(self.time_ada.bias, 0)
205
+ nn.init.constant_(self.time_ada_final.weight, 0)
206
+ nn.init.constant_(self.time_ada_final.bias, 0)
207
+ for block in self.in_blocks:
208
+ nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
209
+ a=math.sqrt(5))
210
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
211
+ nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
212
+ a=math.sqrt(5))
213
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
214
+ for block in self.out_blocks:
215
+ nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
216
+ a=math.sqrt(5))
217
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
218
+
219
+ def initialize_weights(self):
220
+ # Basic init for all layers
221
+ def _basic_init(module):
222
+ if isinstance(module, nn.Linear):
223
+ torch.nn.init.xavier_uniform_(module.weight)
224
+ if module.bias is not None:
225
+ nn.init.constant_(module.bias, 0)
226
+ self.apply(_basic_init)
227
+
228
+ # init patch Conv like Linear
229
+ w = self.patch_embed.proj.weight.data
230
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
231
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
232
+
233
+ # Zero-out AdaLN
234
+ if self.use_adanorm:
235
+ self._init_ada()
236
+
237
+ # Zero-out Cross Attention
238
+ if self.context_cross:
239
+ for block in self.in_blocks:
240
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
241
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
242
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
243
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
244
+ for block in self.out_blocks:
245
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
246
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
247
+
248
+ # Zero-out cls embedding
249
+ if self.cls_embed:
250
+ if self.use_adanorm:
251
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
252
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
253
+
254
+ # Zero-out Output
255
+ # might not zero-out this when using v-prediction
256
+ # it could be good when using noise-prediction
257
+ # nn.init.constant_(self.final_block.linear.weight, 0)
258
+ # nn.init.constant_(self.final_block.linear.bias, 0)
259
+ # if self.use_conv:
260
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
261
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
262
+
263
+ # init out Conv
264
+ if self.use_conv:
265
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
266
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
267
+
268
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
269
+ assert context.shape[-2] == self.context_max_length
270
+ # Check if either x_mask or context_mask is provided
271
+ B = x.shape[0]
272
+ # Create default masks if they are not provided
273
+ if x_mask is None:
274
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
275
+ if context_mask is None:
276
+ context_mask = torch.ones(B, context.shape[-2],
277
+ device=context.device).bool()
278
+ # Concatenate the masks along the second dimension (dim=1)
279
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
280
+ # Concatenate context and x along the second dimension (dim=1)
281
+ x = torch.cat((context, x), dim=1)
282
+ return x, x_mask
283
+
284
+ def forward(self, x, timesteps, context,
285
+ x_mask=None, context_mask=None,
286
+ cls_token=None
287
+ ):
288
+ # make it compatible with int time step during inference
289
+ if timesteps.dim() == 0:
290
+ timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
291
+
292
+ x = self.patch_embed(x)
293
+ x = self.x_pe(x)
294
+
295
+ B, L, D = x.shape
296
+
297
+ if self.use_context:
298
+ context_token = self.context_embed(context)
299
+ context_token = self.context_pe(context_token)
300
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
301
+ x, x_mask = self._concat_x_context(x=x, context=context_token,
302
+ x_mask=x_mask,
303
+ context_mask=context_mask)
304
+ context_token, context_mask = None, None
305
+ else:
306
+ context_token, context_mask = None, None
307
+
308
+ time_token = self.time_embed(timesteps)
309
+ if self.cls_embed:
310
+ cls_token = self.cls_embed(cls_token)
311
+ time_ada = None
312
+ time_ada_final = None
313
+ if self.use_adanorm:
314
+ if self.cls_embed:
315
+ time_token = time_token + cls_token
316
+ time_token = self.time_act(time_token)
317
+ time_ada_final = self.time_ada_final(time_token)
318
+ if self.time_ada is not None:
319
+ time_ada = self.time_ada(time_token)
320
+ else:
321
+ time_token = time_token.unsqueeze(dim=1)
322
+ if self.cls_embed:
323
+ cls_token = cls_token.unsqueeze(dim=1)
324
+ time_token = torch.cat([time_token, cls_token], dim=1)
325
+ time_token = self.time_pe(time_token)
326
+ x = torch.cat((time_token, x), dim=1)
327
+ if x_mask is not None:
328
+ x_mask = torch.cat(
329
+ [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
330
+ x_mask], dim=1)
331
+ time_token = None
332
+
333
+ skips = []
334
+ for blk in self.in_blocks:
335
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
336
+ skip=None, context=context_token,
337
+ x_mask=x_mask, context_mask=context_mask,
338
+ extras=self.extras)
339
+ if self.use_skip:
340
+ skips.append(x)
341
+
342
+ x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
343
+ skip=None, context=context_token,
344
+ x_mask=x_mask, context_mask=context_mask,
345
+ extras=self.extras)
346
+
347
+ for blk in self.out_blocks:
348
+ skip = skips.pop() if self.use_skip else None
349
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
350
+ skip=skip, context=context_token,
351
+ x_mask=x_mask, context_mask=context_mask,
352
+ extras=self.extras)
353
+
354
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
355
+
356
+ return x
src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py ADDED
File without changes
src/models/utils/.ipynb_checkpoints/attention-checkpoint.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
+
12
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
13
+ ATTENTION_MODE = 'flash'
14
+ else:
15
+ ATTENTION_MODE = 'math'
16
+ print(f'attention mode is {ATTENTION_MODE}')
17
+
18
+
19
+ def add_mask(sim, mask):
20
+ b, ndim = sim.shape[0], mask.ndim
21
+ if ndim == 3:
22
+ mask = rearrange(mask, "b n m -> b 1 n m")
23
+ if ndim == 2:
24
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
25
+ max_neg_value = -torch.finfo(sim.dtype).max
26
+ sim = sim.masked_fill(~mask, max_neg_value)
27
+ return sim
28
+
29
+
30
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
31
+ def default(val, d):
32
+ return val if val is not None else (d() if isfunction(d) else d)
33
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
34
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
35
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
36
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
37
+ return attn_mask
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, context_dim=None, num_heads=8,
42
+ qkv_bias=False, qk_scale=None, qk_norm=None,
43
+ attn_drop=0., proj_drop=0., rope_mode='none'):
44
+ super().__init__()
45
+ self.num_heads = num_heads
46
+ head_dim = dim // num_heads
47
+ self.scale = qk_scale or head_dim ** -0.5
48
+
49
+ if context_dim is None:
50
+ self.cross_attn = False
51
+ else:
52
+ self.cross_attn = True
53
+
54
+ context_dim = dim if context_dim is None else context_dim
55
+
56
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
57
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
58
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
59
+
60
+ if qk_norm is None:
61
+ self.norm_q = nn.Identity()
62
+ self.norm_k = nn.Identity()
63
+ elif qk_norm == 'layernorm':
64
+ self.norm_q = nn.LayerNorm(head_dim)
65
+ self.norm_k = nn.LayerNorm(head_dim)
66
+ elif qk_norm == 'rmsnorm':
67
+ self.norm_q = RMSNorm(head_dim)
68
+ self.norm_k = RMSNorm(head_dim)
69
+ else:
70
+ raise NotImplementedError
71
+
72
+ self.attn_drop_p = attn_drop
73
+ self.attn_drop = nn.Dropout(attn_drop)
74
+ self.proj = nn.Linear(dim, dim)
75
+ self.proj_drop = nn.Dropout(proj_drop)
76
+
77
+ if self.cross_attn:
78
+ assert rope_mode == 'none'
79
+ self.rope_mode = rope_mode
80
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
81
+ self.rotary = RotaryEmbedding(dim=head_dim)
82
+ elif self.rope_mode == 'dual':
83
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
84
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
85
+
86
+ def _rotary(self, q, k, extras):
87
+ if self.rope_mode == 'shared':
88
+ q, k = self.rotary(q=q, k=k)
89
+ elif self.rope_mode == 'x_only':
90
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
91
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
92
+ q = torch.cat((q_c, q_x), dim=2)
93
+ k = torch.cat((k_c, k_x), dim=2)
94
+ elif self.rope_mode == 'dual':
95
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
96
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
97
+ q = torch.cat((q_c, q_x), dim=2)
98
+ k = torch.cat((k_c, k_x), dim=2)
99
+ elif self.rope_mode == 'none':
100
+ pass
101
+ else:
102
+ raise NotImplementedError
103
+ return q, k
104
+
105
+ def _attn(self, q, k, v, mask_binary):
106
+ if ATTENTION_MODE == 'flash':
107
+ x = F.scaled_dot_product_attention(q, k, v,
108
+ dropout_p=self.attn_drop_p,
109
+ attn_mask=mask_binary)
110
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
111
+ elif ATTENTION_MODE == 'math':
112
+ attn = (q @ k.transpose(-2, -1)) * self.scale
113
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
114
+ attn = attn.softmax(dim=-1)
115
+ attn = self.attn_drop(attn)
116
+ x = (attn @ v).transpose(1, 2)
117
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
118
+ else:
119
+ raise NotImplementedError
120
+ return x
121
+
122
+ def forward(self, x, context=None, context_mask=None, extras=0):
123
+ B, L, C = x.shape
124
+ if context is None:
125
+ context = x
126
+
127
+ q = self.to_q(x)
128
+ k = self.to_k(context)
129
+ v = self.to_v(context)
130
+
131
+ if context_mask is not None:
132
+ mask_binary = create_mask(x.shape, context.shape,
133
+ x.device, None, context_mask)
134
+ else:
135
+ mask_binary = None
136
+
137
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
138
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
139
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
140
+
141
+ q = self.norm_q(q)
142
+ k = self.norm_k(k)
143
+
144
+ q, k = self._rotary(q, k, extras)
145
+
146
+ x = self._attn(q, k, v, mask_binary)
147
+
148
+ x = self.proj(x)
149
+ x = self.proj_drop(x)
150
+ return x
151
+
152
+
153
+ class JointAttention(nn.Module):
154
+ def __init__(self, dim, num_heads=8,
155
+ qkv_bias=False, qk_scale=None, qk_norm=None,
156
+ attn_drop=0., proj_drop=0.,
157
+ rope_mode='none'):
158
+ super().__init__()
159
+ self.num_heads = num_heads
160
+ head_dim = dim // num_heads
161
+ self.scale = qk_scale or head_dim ** -0.5
162
+
163
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
164
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)
165
+
166
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
167
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
168
+
169
+ self.attn_drop_p = attn_drop
170
+ self.attn_drop = nn.Dropout(attn_drop)
171
+
172
+ self.proj_x = nn.Linear(dim, dim)
173
+ self.proj_drop_x = nn.Dropout(proj_drop)
174
+
175
+ self.proj_c = nn.Linear(dim, dim)
176
+ self.proj_drop_c = nn.Dropout(proj_drop)
177
+
178
+ self.rope_mode = rope_mode
179
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
180
+ self.rotary = RotaryEmbedding(dim=head_dim)
181
+ elif self.rope_mode == 'dual':
182
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
183
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
184
+
185
+ def _make_qkv_layers(self, dim, qkv_bias):
186
+ return (nn.Linear(dim, dim, bias=qkv_bias),
187
+ nn.Linear(dim, dim, bias=qkv_bias),
188
+ nn.Linear(dim, dim, bias=qkv_bias))
189
+
190
+ def _make_norm_layers(self, qk_norm, head_dim):
191
+ if qk_norm is None:
192
+ norm_q = nn.Identity()
193
+ norm_k = nn.Identity()
194
+ elif qk_norm == 'layernorm':
195
+ norm_q = nn.LayerNorm(head_dim)
196
+ norm_k = nn.LayerNorm(head_dim)
197
+ elif qk_norm == 'rmsnorm':
198
+ norm_q = RMSNorm(head_dim)
199
+ norm_k = RMSNorm(head_dim)
200
+ else:
201
+ raise NotImplementedError
202
+ return norm_q, norm_k
203
+
204
+ def _rotary(self, q, k, extras):
205
+ if self.rope_mode == 'shared':
206
+ q, k = self.rotary(q=q, k=k)
207
+ elif self.rope_mode == 'x_only':
208
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
209
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
210
+ q = torch.cat((q_c, q_x), dim=2)
211
+ k = torch.cat((k_c, k_x), dim=2)
212
+ elif self.rope_mode == 'dual':
213
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
214
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
215
+ q = torch.cat((q_c, q_x), dim=2)
216
+ k = torch.cat((k_c, k_x), dim=2)
217
+ elif self.rope_mode == 'none':
218
+ pass
219
+ else:
220
+ raise NotImplementedError
221
+ return q, k
222
+
223
+ def _attn(self, q, k, v, mask_binary):
224
+ if ATTENTION_MODE == 'flash':
225
+ x = F.scaled_dot_product_attention(q, k, v,
226
+ dropout_p=self.attn_drop_p,
227
+ attn_mask=mask_binary)
228
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
229
+ elif ATTENTION_MODE == 'math':
230
+ attn = (q @ k.transpose(-2, -1)) * self.scale
231
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
232
+ attn = attn.softmax(dim=-1)
233
+ attn = self.attn_drop(attn)
234
+ x = (attn @ v).transpose(1, 2)
235
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
236
+ else:
237
+ raise NotImplementedError
238
+ return x
239
+
240
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
241
+ B = x.shape[0]
242
+ if x_mask is None:
243
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
244
+ if context_mask is None:
245
+ context_mask = torch.ones(B, context.shape[-2], device=context.device).bool()
246
+ mask = torch.cat([context_mask, x_mask], dim=1)
247
+ return mask
248
+
249
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
250
+ B, Lx, C = x.shape
251
+ _, Lc, _ = context.shape
252
+ if x_mask is not None or context_mask is not None:
253
+ mask = self._cat_mask(x, context,
254
+ x_mask=x_mask,
255
+ context_mask=context_mask)
256
+ shape = [B, Lx+Lc, C]
257
+ mask_binary = create_mask(q_shape=shape, k_shape=shape,
258
+ device=x.device,
259
+ q_mask=None, k_mask=mask)
260
+ else:
261
+ mask_binary = None
262
+
263
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
264
+ qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)
265
+
266
+ qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
267
+ H=self.num_heads), [qx, kx, vx])
268
+ qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
269
+ H=self.num_heads), [qc, kc, vc])
270
+
271
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
272
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
273
+
274
+ q, k, v = (torch.cat([qc, qx], dim=2),
275
+ torch.cat([kc, kx], dim=2),
276
+ torch.cat([vc, vx], dim=2))
277
+
278
+ q, k = self._rotary(q, k, extras)
279
+
280
+ x = self._attn(q, k, v, mask_binary)
281
+
282
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
283
+
284
+ x = self.proj_x(x)
285
+ x = self.proj_drop_x(x)
286
+
287
+ context = self.proj_c(context)
288
+ context = self.proj_drop_c(context)
289
+
290
+ return x, context
src/models/utils/.ipynb_checkpoints/modules-checkpoint.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ from torch.cuda.amp import autocast
6
+ import math
7
+ import einops
8
+ from einops import rearrange, repeat
9
+ from inspect import isfunction
10
+ from .timm import trunc_normal_
11
+
12
+
13
+ # disable in checkpoint mode
14
+ # @torch.jit.script
15
+ def film_modulate(x, shift, scale):
16
+ return x * (1 + scale) + shift
17
+
18
+
19
+ def timestep_embedding(timesteps, dim, max_period=10000):
20
+ """
21
+ Create sinusoidal timestep embeddings.
22
+
23
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
24
+ These may be fractional.
25
+ :param dim: the dimension of the output.
26
+ :param max_period: controls the minimum frequency of the embeddings.
27
+ :return: an [N x dim] Tensor of positional embeddings.
28
+ """
29
+ half = dim // 2
30
+ freqs = torch.exp(
31
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
32
+ ).to(device=timesteps.device)
33
+ args = timesteps[:, None].float() * freqs[None]
34
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
35
+ if dim % 2:
36
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
37
+ return embedding
38
+
39
+
40
+ class TimestepEmbedder(nn.Module):
41
+ """
42
+ Embeds scalar timesteps into vector representations.
43
+ """
44
+
45
+ def __init__(self, hidden_size, frequency_embedding_size=256,
46
+ out_size=None):
47
+ super().__init__()
48
+ if out_size is None:
49
+ out_size = hidden_size
50
+ self.mlp = nn.Sequential(
51
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
52
+ nn.SiLU(),
53
+ nn.Linear(hidden_size, out_size, bias=True),
54
+ )
55
+ self.frequency_embedding_size = frequency_embedding_size
56
+
57
+ def forward(self, t):
58
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
59
+ self.mlp[0].weight.dtype)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+
64
+ def patchify(imgs, patch_size, input_type='2d'):
65
+ if input_type == '2d':
66
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
67
+ elif input_type == '1d':
68
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
69
+ return x
70
+
71
+
72
+ def unpatchify(x, channels=3, input_type='2d', img_size=None):
73
+ if input_type == '2d':
74
+ patch_size = int((x.shape[2] // channels) ** 0.5)
75
+ # h = w = int(x.shape[1] ** .5)
76
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
77
+ assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
78
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h,
79
+ p1=patch_size, p2=patch_size)
80
+ elif input_type == '1d':
81
+ patch_size = int((x.shape[2] // channels))
82
+ h = x.shape[1]
83
+ assert patch_size * channels == x.shape[2]
84
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
85
+ return x
86
+
87
+
88
+ class PatchEmbed(nn.Module):
89
+ """
90
+ Image to Patch Embedding
91
+ """
92
+
93
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
94
+ super().__init__()
95
+ self.patch_size = patch_size
96
+ self.input_type = input_type
97
+ if input_type == '2d':
98
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
99
+ elif input_type == '1d':
100
+ self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
101
+
102
+ def forward(self, x):
103
+ if self.input_type == '2d':
104
+ B, C, H, W = x.shape
105
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
106
+ elif self.input_type == '1d':
107
+ B, C, H = x.shape
108
+ assert H % self.patch_size == 0
109
+
110
+ x = self.proj(x).flatten(2).transpose(1, 2)
111
+ return x
112
+
113
+
114
+ class PositionalConvEmbedding(nn.Module):
115
+ """
116
+ Relative positional embedding used in HuBERT
117
+ """
118
+
119
+ def __init__(self, dim=768, kernel_size=128, groups=16):
120
+ super().__init__()
121
+ self.conv = nn.Conv1d(
122
+ dim,
123
+ dim,
124
+ kernel_size=kernel_size,
125
+ padding=kernel_size // 2,
126
+ groups=groups,
127
+ bias=True
128
+ )
129
+ self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
130
+
131
+ def forward(self, x):
132
+ # B C T
133
+ x = self.conv(x)
134
+ x = F.gelu(x[:, :, :-1])
135
+ return x
136
+
137
+
138
+ class SinusoidalPositionalEncoding(nn.Module):
139
+ def __init__(self, dim, length):
140
+ super(SinusoidalPositionalEncoding, self).__init__()
141
+ self.length = length
142
+ self.dim = dim
143
+ self.register_buffer('pe', self._generate_positional_encoding(length, dim))
144
+
145
+ def _generate_positional_encoding(self, length, dim):
146
+ pe = torch.zeros(length, dim)
147
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
148
+ div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
149
+
150
+ pe[:, 0::2] = torch.sin(position * div_term)
151
+ pe[:, 1::2] = torch.cos(position * div_term)
152
+
153
+ pe = pe.unsqueeze(0)
154
+ return pe
155
+
156
+ def forward(self, x):
157
+ x = x + self.pe[:, :x.size(1)]
158
+ return x
159
+
160
+
161
+ class PE_wrapper(nn.Module):
162
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
163
+ super().__init__()
164
+ self.method = method
165
+ if method == 'abs':
166
+ # init absolute pe like UViT
167
+ self.length = length
168
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
169
+ trunc_normal_(self.abs_pe, std=.02)
170
+ elif method == 'conv':
171
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
172
+ elif method == 'sinu':
173
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
174
+ elif method == 'none':
175
+ # skip pe
176
+ self.id = nn.Identity()
177
+ else:
178
+ raise NotImplementedError
179
+
180
+ def forward(self, x):
181
+ if self.method == 'abs':
182
+ _, L, _ = x.shape
183
+ assert L <= self.length
184
+ x = x + self.abs_pe[:, :L, :]
185
+ elif self.method == 'conv':
186
+ x = x + self.conv_pe(x)
187
+ elif self.method == 'sinu':
188
+ x = self.sinu_pe(x)
189
+ elif self.method == 'none':
190
+ x = self.id(x)
191
+ else:
192
+ raise NotImplementedError
193
+ return x
194
+
195
+
196
+ class RMSNorm(torch.nn.Module):
197
+ def __init__(self, dim: int, eps: float = 1e-6):
198
+ """
199
+ Initialize the RMSNorm normalization layer.
200
+
201
+ Args:
202
+ dim (int): The dimension of the input tensor.
203
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
204
+
205
+ Attributes:
206
+ eps (float): A small value added to the denominator for numerical stability.
207
+ weight (nn.Parameter): Learnable scaling parameter.
208
+
209
+ """
210
+ super().__init__()
211
+ self.eps = eps
212
+ self.weight = nn.Parameter(torch.ones(dim))
213
+
214
+ def _norm(self, x):
215
+ """
216
+ Apply the RMSNorm normalization to the input tensor.
217
+
218
+ Args:
219
+ x (torch.Tensor): The input tensor.
220
+
221
+ Returns:
222
+ torch.Tensor: The normalized tensor.
223
+
224
+ """
225
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
226
+
227
+ def forward(self, x):
228
+ """
229
+ Forward pass through the RMSNorm layer.
230
+
231
+ Args:
232
+ x (torch.Tensor): The input tensor.
233
+
234
+ Returns:
235
+ torch.Tensor: The output tensor after applying RMSNorm.
236
+
237
+ """
238
+ output = self._norm(x.float()).type_as(x)
239
+ return output * self.weight
240
+
241
+
242
+ class GELU(nn.Module):
243
+
244
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none",
245
+ bias: bool = True):
246
+ super().__init__()
247
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
248
+ self.approximate = approximate
249
+
250
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
251
+ if gate.device.type != "mps":
252
+ return F.gelu(gate, approximate=self.approximate)
253
+ # mps: gelu is not implemented for float16
254
+ return F.gelu(gate.to(dtype=torch.float32),
255
+ approximate=self.approximate).to(dtype=gate.dtype)
256
+
257
+ def forward(self, hidden_states):
258
+ hidden_states = self.proj(hidden_states)
259
+ hidden_states = self.gelu(hidden_states)
260
+ return hidden_states
261
+
262
+
263
+ class GEGLU(nn.Module):
264
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
265
+ super().__init__()
266
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
267
+
268
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
269
+ if gate.device.type != "mps":
270
+ return F.gelu(gate)
271
+ # mps: gelu is not implemented for float16
272
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
273
+
274
+ def forward(self, hidden_states):
275
+ hidden_states = self.proj(hidden_states)
276
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
277
+ return hidden_states * self.gelu(gate)
278
+
279
+
280
+ class ApproximateGELU(nn.Module):
281
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
282
+ super().__init__()
283
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
284
+
285
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
286
+ x = self.proj(x)
287
+ return x * torch.sigmoid(1.702 * x)
288
+
289
+
290
+ # disable in checkpoint mode
291
+ # @torch.jit.script
292
+ def snake_beta(x, alpha, beta):
293
+ return x + beta * torch.sin(x * alpha).pow(2)
294
+
295
+
296
+ class Snake(nn.Module):
297
+ def __init__(self, dim_in, dim_out, bias,
298
+ alpha_trainable=True):
299
+ super().__init__()
300
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
301
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
302
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
303
+ self.alpha.requires_grad = alpha_trainable
304
+ self.beta.requires_grad = alpha_trainable
305
+
306
+ def forward(self, x):
307
+ x = self.proj(x)
308
+ x = snake_beta(x, self.alpha, self.beta)
309
+ return x
310
+
311
+
312
+ class GESnake(nn.Module):
313
+ def __init__(self, dim_in, dim_out, bias,
314
+ alpha_trainable=True):
315
+ super().__init__()
316
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
317
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
318
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
319
+ self.alpha.requires_grad = alpha_trainable
320
+ self.beta.requires_grad = alpha_trainable
321
+
322
+ def forward(self, x):
323
+ x = self.proj(x)
324
+ x, gate = x.chunk(2, dim=-1)
325
+ return x * snake_beta(gate, self.alpha, self.beta)
326
+
327
+
328
+ class FeedForward(nn.Module):
329
+ def __init__(
330
+ self,
331
+ dim,
332
+ dim_out=None,
333
+ mult=4,
334
+ dropout=0.0,
335
+ activation_fn="geglu",
336
+ final_dropout=False,
337
+ inner_dim=None,
338
+ bias=True,
339
+ ):
340
+ super().__init__()
341
+ if inner_dim is None:
342
+ inner_dim = int(dim * mult)
343
+ dim_out = dim_out if dim_out is not None else dim
344
+
345
+ if activation_fn == "gelu":
346
+ act_fn = GELU(dim, inner_dim, bias=bias)
347
+ elif activation_fn == "gelu-approximate":
348
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
349
+ elif activation_fn == "geglu":
350
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
351
+ elif activation_fn == "geglu-approximate":
352
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
353
+ elif activation_fn == "snake":
354
+ act_fn = Snake(dim, inner_dim, bias=bias)
355
+ elif activation_fn == "gesnake":
356
+ act_fn = GESnake(dim, inner_dim, bias=bias)
357
+ else:
358
+ raise NotImplementedError
359
+
360
+ self.net = nn.ModuleList([])
361
+ # project in
362
+ self.net.append(act_fn)
363
+ # project dropout
364
+ self.net.append(nn.Dropout(dropout))
365
+ # project out
366
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
367
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
368
+ if final_dropout:
369
+ self.net.append(nn.Dropout(dropout))
370
+
371
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
372
+ for module in self.net:
373
+ hidden_states = module(hidden_states)
374
+ return hidden_states
src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ "this rope is faster than llama rope with jit script"
4
+
5
+
6
+ def rotate_half(x):
7
+ x1, x2 = x.chunk(2, dim=-1)
8
+ return torch.cat((-x2, x1), dim=-1)
9
+
10
+
11
+ # disable in checkpoint mode
12
+ # @torch.jit.script
13
+ def apply_rotary_pos_emb(x, cos, sin):
14
+ # NOTE: This could probably be moved to Triton
15
+ # Handle a possible sequence length mismatch in between q and k
16
+ cos = cos[:, :, : x.shape[-2], :]
17
+ sin = sin[:, :, : x.shape[-2], :]
18
+ return (x * cos) + (rotate_half(x) * sin)
19
+
20
+
21
+ class RotaryEmbedding(torch.nn.Module):
22
+ """
23
+ The rotary position embeddings from RoFormer_ (Su et. al).
24
+ A crucial insight from the method is that the query and keys are
25
+ transformed by rotation matrices which depend on the relative positions.
26
+
27
+ Other implementations are available in the Rotary Transformer repo_ and in
28
+ GPT-NeoX_, GPT-NeoX was an inspiration
29
+
30
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
31
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
32
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
33
+
34
+
35
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
36
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
37
+ """
38
+
39
+ def __init__(self, dim: int):
40
+ super().__init__()
41
+ # Generate and save the inverse frequency buffer (non trainable)
42
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43
+ self.register_buffer("inv_freq", inv_freq)
44
+ self._seq_len_cached = None
45
+ self._cos_cached = None
46
+ self._sin_cached = None
47
+
48
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
49
+ # expect input: B, H, L, D
50
+ seq_len = x.shape[seq_dimension]
51
+
52
+ # Reset the tables if the sequence length has changed,
53
+ # or if we're on a new device (possibly due to tracing for instance)
54
+ # also make sure dtype wont change
55
+ if (
56
+ seq_len != self._seq_len_cached
57
+ or self._cos_cached.device != x.device
58
+ or self._cos_cached.dtype != x.dtype
59
+ ):
60
+ self._seq_len_cached = seq_len
61
+ t = torch.arange(
62
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
63
+ )
64
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
65
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
66
+
67
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
68
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
69
+
70
+ return self._cos_cached, self._sin_cached
71
+
72
+ def forward(self, q, k):
73
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
74
+ q.float(), seq_dimension=-2
75
+ )
76
+ if k is not None:
77
+ return (
78
+ apply_rotary_pos_emb(q.float(),
79
+ self._cos_cached,
80
+ self._sin_cached).type_as(q),
81
+ apply_rotary_pos_emb(k.float(),
82
+ self._cos_cached,
83
+ self._sin_cached).type_as(k),
84
+ )
85
+ else:
86
+ return (
87
+ apply_rotary_pos_emb(q.float(),
88
+ self._cos_cached,
89
+ self._sin_cached).type_as(q),
90
+ None
91
+ )
src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ def compute_mask_indices(
7
+ shape: Tuple[int, int],
8
+ padding_mask: Optional[torch.Tensor],
9
+ mask_prob: float,
10
+ mask_length: int,
11
+ mask_type: str = "static",
12
+ mask_other: float = 0.0,
13
+ min_masks: int = 0,
14
+ no_overlap: bool = False,
15
+ min_space: int = 0,
16
+ ) -> np.ndarray:
17
+ """
18
+ Computes random mask spans for a given shape
19
+
20
+ Args:
21
+ shape: the the shape for which to compute masks.
22
+ should be of size 2 where first element is batch size and 2nd is timesteps
23
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
24
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
25
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
26
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
27
+ mask_type: how to compute mask lengths
28
+ static = fixed size
29
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
30
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
31
+ poisson = sample from possion distribution with lambda = mask length
32
+ min_masks: minimum number of masked spans
33
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
34
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
35
+ """
36
+
37
+ bsz, all_sz = shape
38
+ mask = np.full((bsz, all_sz), False)
39
+
40
+ # Convert mask_prob to a NumPy array
41
+ mask_prob = np.array(mask_prob)
42
+
43
+ # Calculate all_num_mask for each element in the batch
44
+ all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
45
+
46
+ # Apply the max operation with min_masks for each element
47
+ all_num_mask = np.maximum(min_masks, all_num_mask)
48
+
49
+ mask_idcs = []
50
+ for i in range(bsz):
51
+ if padding_mask is not None:
52
+ sz = all_sz - padding_mask[i].long().sum().item()
53
+ num_mask = int(
54
+ # add a random number for probabilistic rounding
55
+ mask_prob * sz / float(mask_length)
56
+ + np.random.rand()
57
+ )
58
+ num_mask = max(min_masks, num_mask)
59
+ else:
60
+ sz = all_sz
61
+ num_mask = all_num_mask[i]
62
+
63
+ if mask_type == "static":
64
+ lengths = np.full(num_mask, mask_length)
65
+ elif mask_type == "uniform":
66
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
67
+ elif mask_type == "normal":
68
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
69
+ lengths = [max(1, int(round(x))) for x in lengths]
70
+ elif mask_type == "poisson":
71
+ lengths = np.random.poisson(mask_length, size=num_mask)
72
+ lengths = [int(round(x)) for x in lengths]
73
+ else:
74
+ raise Exception("unknown mask selection " + mask_type)
75
+
76
+ if sum(lengths) == 0:
77
+ lengths[0] = min(mask_length, sz - 1)
78
+
79
+ if no_overlap:
80
+ mask_idc = []
81
+
82
+ def arrange(s, e, length, keep_length):
83
+ span_start = np.random.randint(s, e - length)
84
+ mask_idc.extend(span_start + i for i in range(length))
85
+
86
+ new_parts = []
87
+ if span_start - s - min_space >= keep_length:
88
+ new_parts.append((s, span_start - min_space + 1))
89
+ if e - span_start - keep_length - min_space > keep_length:
90
+ new_parts.append((span_start + length + min_space, e))
91
+ return new_parts
92
+
93
+ parts = [(0, sz)]
94
+ min_length = min(lengths)
95
+ for length in sorted(lengths, reverse=True):
96
+ lens = np.fromiter(
97
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
98
+ np.int,
99
+ )
100
+ l_sum = np.sum(lens)
101
+ if l_sum == 0:
102
+ break
103
+ probs = lens / np.sum(lens)
104
+ c = np.random.choice(len(parts), p=probs)
105
+ s, e = parts.pop(c)
106
+ parts.extend(arrange(s, e, length, min_length))
107
+ mask_idc = np.asarray(mask_idc)
108
+ else:
109
+ min_len = min(lengths)
110
+ if sz - min_len <= num_mask:
111
+ min_len = sz - num_mask - 1
112
+
113
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
114
+
115
+ mask_idc = np.asarray(
116
+ [
117
+ mask_idc[j] + offset
118
+ for j in range(len(mask_idc))
119
+ for offset in range(lengths[j])
120
+ ]
121
+ )
122
+
123
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
124
+ # min_len = min([len(m) for m in mask_idcs])
125
+ for i, mask_idc in enumerate(mask_idcs):
126
+ # if len(mask_idc) > min_len:
127
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
128
+ mask[i, mask_idc] = True
129
+
130
+ return torch.tensor(mask)
131
+
132
+
133
+ if __name__ == '__main__':
134
+ mask = compute_mask_indices(
135
+ shape=[4, 500],
136
+ padding_mask=None,
137
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
138
+ mask_length=10,
139
+ mask_type="static",
140
+ mask_other=0.0,
141
+ min_masks=1,
142
+ no_overlap=False,
143
+ min_space=0,
144
+ )
145
+ print(mask)
146
+ print(mask.sum(dim=1))
src/models/utils/.ipynb_checkpoints/timm-checkpoint.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code from timm 0.3.2
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import warnings
6
+
7
+
8
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
10
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11
+ def norm_cdf(x):
12
+ # Computes standard normal cumulative distribution function
13
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
14
+
15
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
16
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17
+ "The distribution of values may be incorrect.",
18
+ stacklevel=2)
19
+
20
+ with torch.no_grad():
21
+ # Values are generated by using a truncated uniform distribution and
22
+ # then using the inverse CDF for the normal distribution.
23
+ # Get upper and lower cdf values
24
+ l = norm_cdf((a - mean) / std)
25
+ u = norm_cdf((b - mean) / std)
26
+
27
+ # Uniformly fill tensor with values from [l, u], then translate to
28
+ # [2l-1, 2u-1].
29
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
30
+
31
+ # Use inverse cdf transform for normal distribution to get truncated
32
+ # standard normal
33
+ tensor.erfinv_()
34
+
35
+ # Transform to proper mean, std
36
+ tensor.mul_(std * math.sqrt(2.))
37
+ tensor.add_(mean)
38
+
39
+ # Clamp to ensure it's in the proper range
40
+ tensor.clamp_(min=a, max=b)
41
+ return tensor
42
+
43
+
44
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45
+ # type: (Tensor, float, float, float, float) -> Tensor
46
+ r"""Fills the input Tensor with values drawn from a truncated
47
+ normal distribution. The values are effectively drawn from the
48
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49
+ with values outside :math:`[a, b]` redrawn until they are within
50
+ the bounds. The method used for generating the random values works
51
+ best when :math:`a \leq \text{mean} \leq b`.
52
+ Args:
53
+ tensor: an n-dimensional `torch.Tensor`
54
+ mean: the mean of the normal distribution
55
+ std: the standard deviation of the normal distribution
56
+ a: the minimum cutoff value
57
+ b: the maximum cutoff value
58
+ Examples:
59
+ >>> w = torch.empty(3, 5)
60
+ >>> nn.init.trunc_normal_(w)
61
+ """
62
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63
+
64
+
65
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
66
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67
+
68
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72
+ 'survival rate' as the argument.
73
+
74
+ """
75
+ if drop_prob == 0. or not training:
76
+ return x
77
+ keep_prob = 1 - drop_prob
78
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80
+ random_tensor.floor_() # binarize
81
+ output = x.div(keep_prob) * random_tensor
82
+ return output
83
+
84
+
85
+ class DropPath(nn.Module):
86
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87
+ """
88
+
89
+ def __init__(self, drop_prob=None):
90
+ super(DropPath, self).__init__()
91
+ self.drop_prob = drop_prob
92
+
93
+ def forward(self, x):
94
+ return drop_path(x, self.drop_prob, self.training)
95
+
96
+
97
+ class Mlp(nn.Module):
98
+ def __init__(self, in_features, hidden_features=None, out_features=None,
99
+ act_layer=nn.GELU, drop=0.):
100
+ super().__init__()
101
+ out_features = out_features or in_features
102
+ hidden_features = hidden_features or in_features
103
+ self.fc1 = nn.Linear(in_features, hidden_features)
104
+ self.act = act_layer()
105
+ self.fc2 = nn.Linear(hidden_features, out_features)
106
+ self.drop = nn.Dropout(drop)
107
+
108
+ def forward(self, x):
109
+ x = self.fc1(x)
110
+ x = self.act(x)
111
+ x = self.drop(x)
112
+ x = self.fc2(x)
113
+ x = self.drop(x)
114
+ return x
src/models/utils/__init__.py ADDED
File without changes
src/models/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (178 Bytes). View file
 
src/models/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes). View file
 
src/models/utils/__pycache__/attention.cpython-310.pyc ADDED
Binary file (7.61 kB). View file
 
src/models/utils/__pycache__/attention.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
src/models/utils/__pycache__/modules.cpython-310.pyc ADDED
Binary file (13.2 kB). View file
 
src/models/utils/__pycache__/modules.cpython-311.pyc ADDED
Binary file (24 kB). View file
 
src/models/utils/__pycache__/rotary.cpython-310.pyc ADDED
Binary file (2.81 kB). View file
 
src/models/utils/__pycache__/rotary.cpython-311.pyc ADDED
Binary file (4.99 kB). View file
 
src/models/utils/__pycache__/span_mask.cpython-310.pyc ADDED
Binary file (4.75 kB). View file
 
src/models/utils/__pycache__/span_mask.cpython-311.pyc ADDED
Binary file (8.51 kB). View file
 
src/models/utils/__pycache__/timm.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
src/models/utils/__pycache__/timm.cpython-311.pyc ADDED
Binary file (6.46 kB). View file
 
src/models/utils/attention.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
+
12
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
13
+ ATTENTION_MODE = 'flash'
14
+ else:
15
+ ATTENTION_MODE = 'math'
16
+ print(f'attention mode is {ATTENTION_MODE}')
17
+
18
+
19
+ def add_mask(sim, mask):
20
+ b, ndim = sim.shape[0], mask.ndim
21
+ if ndim == 3:
22
+ mask = rearrange(mask, "b n m -> b 1 n m")
23
+ if ndim == 2:
24
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
25
+ max_neg_value = -torch.finfo(sim.dtype).max
26
+ sim = sim.masked_fill(~mask, max_neg_value)
27
+ return sim
28
+
29
+
30
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
31
+ def default(val, d):
32
+ return val if val is not None else (d() if isfunction(d) else d)
33
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
34
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
35
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
36
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
37
+ return attn_mask
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, context_dim=None, num_heads=8,
42
+ qkv_bias=False, qk_scale=None, qk_norm=None,
43
+ attn_drop=0., proj_drop=0., rope_mode='none'):
44
+ super().__init__()
45
+ self.num_heads = num_heads
46
+ head_dim = dim // num_heads
47
+ self.scale = qk_scale or head_dim ** -0.5
48
+
49
+ if context_dim is None:
50
+ self.cross_attn = False
51
+ else:
52
+ self.cross_attn = True
53
+
54
+ context_dim = dim if context_dim is None else context_dim
55
+
56
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
57
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
58
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
59
+
60
+ if qk_norm is None:
61
+ self.norm_q = nn.Identity()
62
+ self.norm_k = nn.Identity()
63
+ elif qk_norm == 'layernorm':
64
+ self.norm_q = nn.LayerNorm(head_dim)
65
+ self.norm_k = nn.LayerNorm(head_dim)
66
+ elif qk_norm == 'rmsnorm':
67
+ self.norm_q = RMSNorm(head_dim)
68
+ self.norm_k = RMSNorm(head_dim)
69
+ else:
70
+ raise NotImplementedError
71
+
72
+ self.attn_drop_p = attn_drop
73
+ self.attn_drop = nn.Dropout(attn_drop)
74
+ self.proj = nn.Linear(dim, dim)
75
+ self.proj_drop = nn.Dropout(proj_drop)
76
+
77
+ if self.cross_attn:
78
+ assert rope_mode == 'none'
79
+ self.rope_mode = rope_mode
80
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
81
+ self.rotary = RotaryEmbedding(dim=head_dim)
82
+ elif self.rope_mode == 'dual':
83
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
84
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
85
+
86
+ def _rotary(self, q, k, extras):
87
+ if self.rope_mode == 'shared':
88
+ q, k = self.rotary(q=q, k=k)
89
+ elif self.rope_mode == 'x_only':
90
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
91
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
92
+ q = torch.cat((q_c, q_x), dim=2)
93
+ k = torch.cat((k_c, k_x), dim=2)
94
+ elif self.rope_mode == 'dual':
95
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
96
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
97
+ q = torch.cat((q_c, q_x), dim=2)
98
+ k = torch.cat((k_c, k_x), dim=2)
99
+ elif self.rope_mode == 'none':
100
+ pass
101
+ else:
102
+ raise NotImplementedError
103
+ return q, k
104
+
105
+ def _attn(self, q, k, v, mask_binary):
106
+ if ATTENTION_MODE == 'flash':
107
+ x = F.scaled_dot_product_attention(q, k, v,
108
+ dropout_p=self.attn_drop_p,
109
+ attn_mask=mask_binary)
110
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
111
+ elif ATTENTION_MODE == 'math':
112
+ attn = (q @ k.transpose(-2, -1)) * self.scale
113
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
114
+ attn = attn.softmax(dim=-1)
115
+ attn = self.attn_drop(attn)
116
+ x = (attn @ v).transpose(1, 2)
117
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
118
+ else:
119
+ raise NotImplementedError
120
+ return x
121
+
122
+ def forward(self, x, context=None, context_mask=None, extras=0):
123
+ B, L, C = x.shape
124
+ if context is None:
125
+ context = x
126
+
127
+ q = self.to_q(x)
128
+ k = self.to_k(context)
129
+ v = self.to_v(context)
130
+
131
+ if context_mask is not None:
132
+ mask_binary = create_mask(x.shape, context.shape,
133
+ x.device, None, context_mask)
134
+ else:
135
+ mask_binary = None
136
+
137
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
138
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
139
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
140
+
141
+ q = self.norm_q(q)
142
+ k = self.norm_k(k)
143
+
144
+ q, k = self._rotary(q, k, extras)
145
+
146
+ x = self._attn(q, k, v, mask_binary)
147
+
148
+ x = self.proj(x)
149
+ x = self.proj_drop(x)
150
+ return x
151
+
152
+
153
+ class JointAttention(nn.Module):
154
+ def __init__(self, dim, num_heads=8,
155
+ qkv_bias=False, qk_scale=None, qk_norm=None,
156
+ attn_drop=0., proj_drop=0.,
157
+ rope_mode='none'):
158
+ super().__init__()
159
+ self.num_heads = num_heads
160
+ head_dim = dim // num_heads
161
+ self.scale = qk_scale or head_dim ** -0.5
162
+
163
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
164
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)
165
+
166
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
167
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
168
+
169
+ self.attn_drop_p = attn_drop
170
+ self.attn_drop = nn.Dropout(attn_drop)
171
+
172
+ self.proj_x = nn.Linear(dim, dim)
173
+ self.proj_drop_x = nn.Dropout(proj_drop)
174
+
175
+ self.proj_c = nn.Linear(dim, dim)
176
+ self.proj_drop_c = nn.Dropout(proj_drop)
177
+
178
+ self.rope_mode = rope_mode
179
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
180
+ self.rotary = RotaryEmbedding(dim=head_dim)
181
+ elif self.rope_mode == 'dual':
182
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
183
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
184
+
185
+ def _make_qkv_layers(self, dim, qkv_bias):
186
+ return (nn.Linear(dim, dim, bias=qkv_bias),
187
+ nn.Linear(dim, dim, bias=qkv_bias),
188
+ nn.Linear(dim, dim, bias=qkv_bias))
189
+
190
+ def _make_norm_layers(self, qk_norm, head_dim):
191
+ if qk_norm is None:
192
+ norm_q = nn.Identity()
193
+ norm_k = nn.Identity()
194
+ elif qk_norm == 'layernorm':
195
+ norm_q = nn.LayerNorm(head_dim)
196
+ norm_k = nn.LayerNorm(head_dim)
197
+ elif qk_norm == 'rmsnorm':
198
+ norm_q = RMSNorm(head_dim)
199
+ norm_k = RMSNorm(head_dim)
200
+ else:
201
+ raise NotImplementedError
202
+ return norm_q, norm_k
203
+
204
+ def _rotary(self, q, k, extras):
205
+ if self.rope_mode == 'shared':
206
+ q, k = self.rotary(q=q, k=k)
207
+ elif self.rope_mode == 'x_only':
208
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
209
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
210
+ q = torch.cat((q_c, q_x), dim=2)
211
+ k = torch.cat((k_c, k_x), dim=2)
212
+ elif self.rope_mode == 'dual':
213
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
214
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
215
+ q = torch.cat((q_c, q_x), dim=2)
216
+ k = torch.cat((k_c, k_x), dim=2)
217
+ elif self.rope_mode == 'none':
218
+ pass
219
+ else:
220
+ raise NotImplementedError
221
+ return q, k
222
+
223
+ def _attn(self, q, k, v, mask_binary):
224
+ if ATTENTION_MODE == 'flash':
225
+ x = F.scaled_dot_product_attention(q, k, v,
226
+ dropout_p=self.attn_drop_p,
227
+ attn_mask=mask_binary)
228
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
229
+ elif ATTENTION_MODE == 'math':
230
+ attn = (q @ k.transpose(-2, -1)) * self.scale
231
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
232
+ attn = attn.softmax(dim=-1)
233
+ attn = self.attn_drop(attn)
234
+ x = (attn @ v).transpose(1, 2)
235
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
236
+ else:
237
+ raise NotImplementedError
238
+ return x
239
+
240
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
241
+ B = x.shape[0]
242
+ if x_mask is None:
243
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
244
+ if context_mask is None:
245
+ context_mask = torch.ones(B, context.shape[-2], device=context.device).bool()
246
+ mask = torch.cat([context_mask, x_mask], dim=1)
247
+ return mask
248
+
249
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
250
+ B, Lx, C = x.shape
251
+ _, Lc, _ = context.shape
252
+ if x_mask is not None or context_mask is not None:
253
+ mask = self._cat_mask(x, context,
254
+ x_mask=x_mask,
255
+ context_mask=context_mask)
256
+ shape = [B, Lx+Lc, C]
257
+ mask_binary = create_mask(q_shape=shape, k_shape=shape,
258
+ device=x.device,
259
+ q_mask=None, k_mask=mask)
260
+ else:
261
+ mask_binary = None
262
+
263
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
264
+ qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)
265
+
266
+ qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
267
+ H=self.num_heads), [qx, kx, vx])
268
+ qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
269
+ H=self.num_heads), [qc, kc, vc])
270
+
271
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
272
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
273
+
274
+ q, k, v = (torch.cat([qc, qx], dim=2),
275
+ torch.cat([kc, kx], dim=2),
276
+ torch.cat([vc, vx], dim=2))
277
+
278
+ q, k = self._rotary(q, k, extras)
279
+
280
+ x = self._attn(q, k, v, mask_binary)
281
+
282
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
283
+
284
+ x = self.proj_x(x)
285
+ x = self.proj_drop_x(x)
286
+
287
+ context = self.proj_c(context)
288
+ context = self.proj_drop_c(context)
289
+
290
+ return x, context
src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint
4
+ import einops
5
+ from einops import rearrange, repeat
6
+ from inspect import isfunction
7
+ from .rotary import RotaryEmbedding
8
+
9
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
10
+ ATTENTION_MODE = 'flash'
11
+ else:
12
+ ATTENTION_MODE = 'math'
13
+ print(f'attention mode is {ATTENTION_MODE}')
14
+
15
+
16
+ def add_mask(sim, mask):
17
+ b, ndim = sim.shape[0], mask.ndim
18
+ if ndim == 3:
19
+ mask = rearrange(mask, "b n m -> b 1 n m")
20
+ if ndim == 2:
21
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
22
+ max_neg_value = -torch.finfo(sim.dtype).max
23
+ sim = sim.masked_fill(~mask, max_neg_value)
24
+ return sim
25
+
26
+
27
+ def create_mask(q, k, q_mask=None, k_mask=None):
28
+ def default(val, d):
29
+ return val if val is not None else (d() if isfunction(d) else d)
30
+
31
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
32
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
33
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
34
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
35
+ return attn_mask
36
+
37
+
38
+ class Attention(nn.Module):
39
+ def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None,
40
+ attn_drop=0., proj_drop=0., use_rope=False):
41
+ super().__init__()
42
+ self.num_heads = num_heads
43
+ head_dim = dim // num_heads
44
+ self.scale = qk_scale or head_dim ** -0.5
45
+
46
+ context_dim = dim if context_dim is None else context_dim
47
+
48
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
49
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
50
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
51
+ self.attn_drop_p = attn_drop
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ self.use_rope = use_rope
57
+ if self.use_rope:
58
+ self.rotary = RotaryEmbedding(dim=head_dim)
59
+
60
+ def forward(self, x, context=None, context_mask=None):
61
+ B, L, C = x.shape
62
+ q = self.to_q(x)
63
+ if context is None:
64
+ context = x
65
+ else:
66
+ assert self.use_rope is False
67
+
68
+ k = self.to_k(context)
69
+ v = self.to_v(context)
70
+
71
+ if context_mask is not None:
72
+ mask_binary = create_mask(x, context, None, context_mask)
73
+ else:
74
+ mask_binary = None
75
+
76
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float()
77
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float()
78
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float()
79
+
80
+ if self.use_rope:
81
+ q, k = self.rotary(q=q, k=k)
82
+
83
+ if ATTENTION_MODE == 'flash':
84
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v,
85
+ dropout_p=self.attn_drop_p,
86
+ attn_mask=mask_binary)
87
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
88
+ elif ATTENTION_MODE == 'math':
89
+ attn = (q @ k.transpose(-2, -1)) * self.scale
90
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
91
+ attn = attn.softmax(dim=-1)
92
+ attn = self.attn_drop(attn)
93
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
94
+ else:
95
+ raise NotImplementedError
96
+
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+ return x
src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+ from rotary import RotaryEmbedding
4
+ import time
5
+
6
+
7
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
8
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
9
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
10
+ freqs = torch.outer(t, freqs)
11
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
12
+ return freqs_cis
13
+
14
+
15
+ def reshape_for_broadcast(freqs_cis: torch.Tensor,
16
+ x: torch.Tensor,):
17
+ ndim = x.ndim
18
+ assert 0 <= 1 < ndim
19
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
20
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
21
+ return freqs_cis.view(*shape)
22
+
23
+
24
+ def compute_rope(q, freqs_cis):
25
+ return q * freqs_cis
26
+
27
+
28
+ def apply_rotary_emb(
29
+ xq: torch.Tensor,
30
+ xk: torch.Tensor,
31
+ freqs_cis: torch.Tensor,
32
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
33
+ # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
34
+ # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
35
+ xq1, xq2 = xq.chunk(2, dim=-1)
36
+ xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float())
37
+
38
+ xk1, xk2 = xk.chunk(2, dim=-1)
39
+ xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float())
40
+
41
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
42
+ xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3)
43
+ xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3)
44
+ return xq_out.type_as(xq), xk_out.type_as(xk)
45
+
46
+
47
+ if __name__ == '__main__':
48
+ # Move data to CUDA
49
+ freq_cis = precompute_freqs_cis(4, 5).cuda()
50
+ x = torch.rand(1, 5, 1, 4).cuda()
51
+ y = torch.rand(1, 5, 1, 4).cuda()
52
+
53
+ # First method
54
+ start_time = time.time()
55
+ for _ in range(20000):
56
+ x1, y1 = apply_rotary_emb(x, y, freq_cis)
57
+ end_time = time.time()
58
+ print(f"Method 1 time cost: {end_time - start_time} seconds")
59
+
60
+ # Prepare data for the second method
61
+ x = x.permute(0, 2, 1, 3)
62
+ y = y.permute(0, 2, 1, 3)
63
+ rope = RotaryEmbedding(4).cuda()
64
+
65
+ # Second method
66
+ start_time = time.time()
67
+ for _ in range(20000):
68
+ x2, y2 = rope(x, y)
69
+ end_time = time.time()
70
+ print(f"Method 2 time cost: {end_time - start_time} seconds")
71
+
72
+ # Print the results
73
+ print(x1)
74
+ print(x2)
src/models/utils/bk/__pycache__/rotary.cpython-311.pyc ADDED
Binary file (4.8 kB). View file
 
src/models/utils/bk/attention.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint
4
+ import einops
5
+ from einops import rearrange, repeat
6
+ from inspect import isfunction
7
+ from .rotary import RotaryEmbedding
8
+
9
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
10
+ ATTENTION_MODE = 'flash'
11
+ else:
12
+ ATTENTION_MODE = 'math'
13
+ print(f'attention mode is {ATTENTION_MODE}')
14
+
15
+
16
+ def add_mask(sim, mask):
17
+ b, ndim = sim.shape[0], mask.ndim
18
+ if ndim == 3:
19
+ mask = rearrange(mask, "b n m -> b 1 n m")
20
+ if ndim == 2:
21
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
22
+ max_neg_value = -torch.finfo(sim.dtype).max
23
+ sim = sim.masked_fill(~mask, max_neg_value)
24
+ return sim
25
+
26
+
27
+ def create_mask(q, k, q_mask=None, k_mask=None):
28
+ def default(val, d):
29
+ return val if val is not None else (d() if isfunction(d) else d)
30
+
31
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
32
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
33
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
34
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
35
+ return attn_mask
36
+
37
+
38
+ class Attention(nn.Module):
39
+ def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None,
40
+ attn_drop=0., proj_drop=0., use_rope=False):
41
+ super().__init__()
42
+ self.num_heads = num_heads
43
+ head_dim = dim // num_heads
44
+ self.scale = qk_scale or head_dim ** -0.5
45
+
46
+ context_dim = dim if context_dim is None else context_dim
47
+
48
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
49
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
50
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
51
+ self.attn_drop_p = attn_drop
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ self.use_rope = use_rope
57
+ if self.use_rope:
58
+ self.rotary = RotaryEmbedding(dim=head_dim)
59
+
60
+ def forward(self, x, context=None, context_mask=None):
61
+ B, L, C = x.shape
62
+ q = self.to_q(x)
63
+ if context is None:
64
+ context = x
65
+ else:
66
+ assert self.use_rope is False
67
+
68
+ k = self.to_k(context)
69
+ v = self.to_v(context)
70
+
71
+ if context_mask is not None:
72
+ mask_binary = create_mask(x, context, None, context_mask)
73
+ else:
74
+ mask_binary = None
75
+
76
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float()
77
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float()
78
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float()
79
+
80
+ if self.use_rope:
81
+ q, k = self.rotary(q=q, k=k)
82
+
83
+ if ATTENTION_MODE == 'flash':
84
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v,
85
+ dropout_p=self.attn_drop_p,
86
+ attn_mask=mask_binary)
87
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
88
+ elif ATTENTION_MODE == 'math':
89
+ attn = (q @ k.transpose(-2, -1)) * self.scale
90
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
91
+ attn = attn.softmax(dim=-1)
92
+ attn = self.attn_drop(attn)
93
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
94
+ else:
95
+ raise NotImplementedError
96
+
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+ return x
src/models/utils/bk/llama_rotary.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+ from rotary import RotaryEmbedding
4
+ import time
5
+
6
+
7
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
8
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
9
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
10
+ freqs = torch.outer(t, freqs)
11
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
12
+ return freqs_cis
13
+
14
+
15
+ def reshape_for_broadcast(freqs_cis: torch.Tensor,
16
+ x: torch.Tensor,):
17
+ ndim = x.ndim
18
+ assert 0 <= 1 < ndim
19
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
20
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
21
+ return freqs_cis.view(*shape)
22
+
23
+
24
+ def compute_rope(q, freqs_cis):
25
+ return q * freqs_cis
26
+
27
+
28
+ def apply_rotary_emb(
29
+ xq: torch.Tensor,
30
+ xk: torch.Tensor,
31
+ freqs_cis: torch.Tensor,
32
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
33
+ # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
34
+ # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
35
+ xq1, xq2 = xq.chunk(2, dim=-1)
36
+ xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float())
37
+
38
+ xk1, xk2 = xk.chunk(2, dim=-1)
39
+ xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float())
40
+
41
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
42
+ xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3)
43
+ xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3)
44
+ return xq_out.type_as(xq), xk_out.type_as(xk)
45
+
46
+
47
+ if __name__ == '__main__':
48
+ # Move data to CUDA
49
+ freq_cis = precompute_freqs_cis(4, 5).cuda()
50
+ x = torch.rand(1, 5, 1, 4).cuda()
51
+ y = torch.rand(1, 5, 1, 4).cuda()
52
+
53
+ # First method
54
+ start_time = time.time()
55
+ for _ in range(20000):
56
+ x1, y1 = apply_rotary_emb(x, y, freq_cis)
57
+ end_time = time.time()
58
+ print(f"Method 1 time cost: {end_time - start_time} seconds")
59
+
60
+ # Prepare data for the second method
61
+ x = x.permute(0, 2, 1, 3)
62
+ y = y.permute(0, 2, 1, 3)
63
+ rope = RotaryEmbedding(4).cuda()
64
+
65
+ # Second method
66
+ start_time = time.time()
67
+ for _ in range(20000):
68
+ x2, y2 = rope(x, y)
69
+ end_time = time.time()
70
+ print(f"Method 2 time cost: {end_time - start_time} seconds")
71
+
72
+ # Print the results
73
+ print(x1)
74
+ print(x2)
src/models/utils/modules.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ from torch.cuda.amp import autocast
6
+ import math
7
+ import einops
8
+ from einops import rearrange, repeat
9
+ from inspect import isfunction
10
+ from .timm import trunc_normal_
11
+
12
+
13
+ # disable in checkpoint mode
14
+ # @torch.jit.script
15
+ def film_modulate(x, shift, scale):
16
+ return x * (1 + scale) + shift
17
+
18
+
19
+ def timestep_embedding(timesteps, dim, max_period=10000):
20
+ """
21
+ Create sinusoidal timestep embeddings.
22
+
23
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
24
+ These may be fractional.
25
+ :param dim: the dimension of the output.
26
+ :param max_period: controls the minimum frequency of the embeddings.
27
+ :return: an [N x dim] Tensor of positional embeddings.
28
+ """
29
+ half = dim // 2
30
+ freqs = torch.exp(
31
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
32
+ ).to(device=timesteps.device)
33
+ args = timesteps[:, None].float() * freqs[None]
34
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
35
+ if dim % 2:
36
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
37
+ return embedding
38
+
39
+
40
+ class TimestepEmbedder(nn.Module):
41
+ """
42
+ Embeds scalar timesteps into vector representations.
43
+ """
44
+
45
+ def __init__(self, hidden_size, frequency_embedding_size=256,
46
+ out_size=None):
47
+ super().__init__()
48
+ if out_size is None:
49
+ out_size = hidden_size
50
+ self.mlp = nn.Sequential(
51
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
52
+ nn.SiLU(),
53
+ nn.Linear(hidden_size, out_size, bias=True),
54
+ )
55
+ self.frequency_embedding_size = frequency_embedding_size
56
+
57
+ def forward(self, t):
58
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
59
+ self.mlp[0].weight.dtype)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+
64
+ def patchify(imgs, patch_size, input_type='2d'):
65
+ if input_type == '2d':
66
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
67
+ elif input_type == '1d':
68
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
69
+ return x
70
+
71
+
72
+ def unpatchify(x, channels=3, input_type='2d', img_size=None):
73
+ if input_type == '2d':
74
+ patch_size = int((x.shape[2] // channels) ** 0.5)
75
+ # h = w = int(x.shape[1] ** .5)
76
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
77
+ assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
78
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h,
79
+ p1=patch_size, p2=patch_size)
80
+ elif input_type == '1d':
81
+ patch_size = int((x.shape[2] // channels))
82
+ h = x.shape[1]
83
+ assert patch_size * channels == x.shape[2]
84
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
85
+ return x
86
+
87
+
88
+ class PatchEmbed(nn.Module):
89
+ """
90
+ Image to Patch Embedding
91
+ """
92
+
93
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
94
+ super().__init__()
95
+ self.patch_size = patch_size
96
+ self.input_type = input_type
97
+ if input_type == '2d':
98
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
99
+ elif input_type == '1d':
100
+ self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
101
+
102
+ def forward(self, x):
103
+ if self.input_type == '2d':
104
+ B, C, H, W = x.shape
105
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
106
+ elif self.input_type == '1d':
107
+ B, C, H = x.shape
108
+ assert H % self.patch_size == 0
109
+
110
+ x = self.proj(x).flatten(2).transpose(1, 2)
111
+ return x
112
+
113
+
114
+ class PositionalConvEmbedding(nn.Module):
115
+ """
116
+ Relative positional embedding used in HuBERT
117
+ """
118
+
119
+ def __init__(self, dim=768, kernel_size=128, groups=16):
120
+ super().__init__()
121
+ self.conv = nn.Conv1d(
122
+ dim,
123
+ dim,
124
+ kernel_size=kernel_size,
125
+ padding=kernel_size // 2,
126
+ groups=groups,
127
+ bias=True
128
+ )
129
+ self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
130
+
131
+ def forward(self, x):
132
+ # B C T
133
+ x = self.conv(x)
134
+ x = F.gelu(x[:, :, :-1])
135
+ return x
136
+
137
+
138
+ class SinusoidalPositionalEncoding(nn.Module):
139
+ def __init__(self, dim, length):
140
+ super(SinusoidalPositionalEncoding, self).__init__()
141
+ self.length = length
142
+ self.dim = dim
143
+ self.register_buffer('pe', self._generate_positional_encoding(length, dim))
144
+
145
+ def _generate_positional_encoding(self, length, dim):
146
+ pe = torch.zeros(length, dim)
147
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
148
+ div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
149
+
150
+ pe[:, 0::2] = torch.sin(position * div_term)
151
+ pe[:, 1::2] = torch.cos(position * div_term)
152
+
153
+ pe = pe.unsqueeze(0)
154
+ return pe
155
+
156
+ def forward(self, x):
157
+ x = x + self.pe[:, :x.size(1)]
158
+ return x
159
+
160
+
161
+ class PE_wrapper(nn.Module):
162
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
163
+ super().__init__()
164
+ self.method = method
165
+ if method == 'abs':
166
+ # init absolute pe like UViT
167
+ self.length = length
168
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
169
+ trunc_normal_(self.abs_pe, std=.02)
170
+ elif method == 'conv':
171
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
172
+ elif method == 'sinu':
173
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
174
+ elif method == 'none':
175
+ # skip pe
176
+ self.id = nn.Identity()
177
+ else:
178
+ raise NotImplementedError
179
+
180
+ def forward(self, x):
181
+ if self.method == 'abs':
182
+ _, L, _ = x.shape
183
+ assert L <= self.length
184
+ x = x + self.abs_pe[:, :L, :]
185
+ elif self.method == 'conv':
186
+ x = x + self.conv_pe(x)
187
+ elif self.method == 'sinu':
188
+ x = self.sinu_pe(x)
189
+ elif self.method == 'none':
190
+ x = self.id(x)
191
+ else:
192
+ raise NotImplementedError
193
+ return x
194
+
195
+
196
+ class RMSNorm(torch.nn.Module):
197
+ def __init__(self, dim: int, eps: float = 1e-6):
198
+ """
199
+ Initialize the RMSNorm normalization layer.
200
+
201
+ Args:
202
+ dim (int): The dimension of the input tensor.
203
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
204
+
205
+ Attributes:
206
+ eps (float): A small value added to the denominator for numerical stability.
207
+ weight (nn.Parameter): Learnable scaling parameter.
208
+
209
+ """
210
+ super().__init__()
211
+ self.eps = eps
212
+ self.weight = nn.Parameter(torch.ones(dim))
213
+
214
+ def _norm(self, x):
215
+ """
216
+ Apply the RMSNorm normalization to the input tensor.
217
+
218
+ Args:
219
+ x (torch.Tensor): The input tensor.
220
+
221
+ Returns:
222
+ torch.Tensor: The normalized tensor.
223
+
224
+ """
225
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
226
+
227
+ def forward(self, x):
228
+ """
229
+ Forward pass through the RMSNorm layer.
230
+
231
+ Args:
232
+ x (torch.Tensor): The input tensor.
233
+
234
+ Returns:
235
+ torch.Tensor: The output tensor after applying RMSNorm.
236
+
237
+ """
238
+ output = self._norm(x.float()).type_as(x)
239
+ return output * self.weight
240
+
241
+
242
+ class GELU(nn.Module):
243
+
244
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none",
245
+ bias: bool = True):
246
+ super().__init__()
247
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
248
+ self.approximate = approximate
249
+
250
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
251
+ if gate.device.type != "mps":
252
+ return F.gelu(gate, approximate=self.approximate)
253
+ # mps: gelu is not implemented for float16
254
+ return F.gelu(gate.to(dtype=torch.float32),
255
+ approximate=self.approximate).to(dtype=gate.dtype)
256
+
257
+ def forward(self, hidden_states):
258
+ hidden_states = self.proj(hidden_states)
259
+ hidden_states = self.gelu(hidden_states)
260
+ return hidden_states
261
+
262
+
263
+ class GEGLU(nn.Module):
264
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
265
+ super().__init__()
266
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
267
+
268
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
269
+ if gate.device.type != "mps":
270
+ return F.gelu(gate)
271
+ # mps: gelu is not implemented for float16
272
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
273
+
274
+ def forward(self, hidden_states):
275
+ hidden_states = self.proj(hidden_states)
276
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
277
+ return hidden_states * self.gelu(gate)
278
+
279
+
280
+ class ApproximateGELU(nn.Module):
281
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
282
+ super().__init__()
283
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
284
+
285
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
286
+ x = self.proj(x)
287
+ return x * torch.sigmoid(1.702 * x)
288
+
289
+
290
+ # disable in checkpoint mode
291
+ # @torch.jit.script
292
+ def snake_beta(x, alpha, beta):
293
+ return x + beta * torch.sin(x * alpha).pow(2)
294
+
295
+
296
+ class Snake(nn.Module):
297
+ def __init__(self, dim_in, dim_out, bias,
298
+ alpha_trainable=True):
299
+ super().__init__()
300
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
301
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
302
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
303
+ self.alpha.requires_grad = alpha_trainable
304
+ self.beta.requires_grad = alpha_trainable
305
+
306
+ def forward(self, x):
307
+ x = self.proj(x)
308
+ x = snake_beta(x, self.alpha, self.beta)
309
+ return x
310
+
311
+
312
+ class GESnake(nn.Module):
313
+ def __init__(self, dim_in, dim_out, bias,
314
+ alpha_trainable=True):
315
+ super().__init__()
316
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
317
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
318
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
319
+ self.alpha.requires_grad = alpha_trainable
320
+ self.beta.requires_grad = alpha_trainable
321
+
322
+ def forward(self, x):
323
+ x = self.proj(x)
324
+ x, gate = x.chunk(2, dim=-1)
325
+ return x * snake_beta(gate, self.alpha, self.beta)
326
+
327
+
328
+ class FeedForward(nn.Module):
329
+ def __init__(
330
+ self,
331
+ dim,
332
+ dim_out=None,
333
+ mult=4,
334
+ dropout=0.0,
335
+ activation_fn="geglu",
336
+ final_dropout=False,
337
+ inner_dim=None,
338
+ bias=True,
339
+ ):
340
+ super().__init__()
341
+ if inner_dim is None:
342
+ inner_dim = int(dim * mult)
343
+ dim_out = dim_out if dim_out is not None else dim
344
+
345
+ if activation_fn == "gelu":
346
+ act_fn = GELU(dim, inner_dim, bias=bias)
347
+ elif activation_fn == "gelu-approximate":
348
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
349
+ elif activation_fn == "geglu":
350
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
351
+ elif activation_fn == "geglu-approximate":
352
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
353
+ elif activation_fn == "snake":
354
+ act_fn = Snake(dim, inner_dim, bias=bias)
355
+ elif activation_fn == "gesnake":
356
+ act_fn = GESnake(dim, inner_dim, bias=bias)
357
+ else:
358
+ raise NotImplementedError
359
+
360
+ self.net = nn.ModuleList([])
361
+ # project in
362
+ self.net.append(act_fn)
363
+ # project dropout
364
+ self.net.append(nn.Dropout(dropout))
365
+ # project out
366
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
367
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
368
+ if final_dropout:
369
+ self.net.append(nn.Dropout(dropout))
370
+
371
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
372
+ for module in self.net:
373
+ hidden_states = module(hidden_states)
374
+ return hidden_states
src/models/utils/rotary.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ "this rope is faster than llama rope with jit script"
4
+
5
+
6
+ def rotate_half(x):
7
+ x1, x2 = x.chunk(2, dim=-1)
8
+ return torch.cat((-x2, x1), dim=-1)
9
+
10
+
11
+ # disable in checkpoint mode
12
+ # @torch.jit.script
13
+ def apply_rotary_pos_emb(x, cos, sin):
14
+ # NOTE: This could probably be moved to Triton
15
+ # Handle a possible sequence length mismatch in between q and k
16
+ cos = cos[:, :, : x.shape[-2], :]
17
+ sin = sin[:, :, : x.shape[-2], :]
18
+ return (x * cos) + (rotate_half(x) * sin)
19
+
20
+
21
+ class RotaryEmbedding(torch.nn.Module):
22
+ """
23
+ The rotary position embeddings from RoFormer_ (Su et. al).
24
+ A crucial insight from the method is that the query and keys are
25
+ transformed by rotation matrices which depend on the relative positions.
26
+
27
+ Other implementations are available in the Rotary Transformer repo_ and in
28
+ GPT-NeoX_, GPT-NeoX was an inspiration
29
+
30
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
31
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
32
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
33
+
34
+
35
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
36
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
37
+ """
38
+
39
+ def __init__(self, dim: int):
40
+ super().__init__()
41
+ # Generate and save the inverse frequency buffer (non trainable)
42
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43
+ self.register_buffer("inv_freq", inv_freq)
44
+ self._seq_len_cached = None
45
+ self._cos_cached = None
46
+ self._sin_cached = None
47
+
48
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
49
+ # expect input: B, H, L, D
50
+ seq_len = x.shape[seq_dimension]
51
+
52
+ # Reset the tables if the sequence length has changed,
53
+ # or if we're on a new device (possibly due to tracing for instance)
54
+ # also make sure dtype wont change
55
+ if (
56
+ seq_len != self._seq_len_cached
57
+ or self._cos_cached.device != x.device
58
+ or self._cos_cached.dtype != x.dtype
59
+ ):
60
+ self._seq_len_cached = seq_len
61
+ t = torch.arange(
62
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
63
+ )
64
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
65
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
66
+
67
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
68
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
69
+
70
+ return self._cos_cached, self._sin_cached
71
+
72
+ def forward(self, q, k):
73
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
74
+ q.float(), seq_dimension=-2
75
+ )
76
+ if k is not None:
77
+ return (
78
+ apply_rotary_pos_emb(q.float(),
79
+ self._cos_cached,
80
+ self._sin_cached).type_as(q),
81
+ apply_rotary_pos_emb(k.float(),
82
+ self._cos_cached,
83
+ self._sin_cached).type_as(k),
84
+ )
85
+ else:
86
+ return (
87
+ apply_rotary_pos_emb(q.float(),
88
+ self._cos_cached,
89
+ self._sin_cached).type_as(q),
90
+ None
91
+ )
src/models/utils/span_mask.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ def compute_mask_indices(
7
+ shape: Tuple[int, int],
8
+ padding_mask: Optional[torch.Tensor],
9
+ mask_prob: float,
10
+ mask_length: int,
11
+ mask_type: str = "static",
12
+ mask_other: float = 0.0,
13
+ min_masks: int = 0,
14
+ no_overlap: bool = False,
15
+ min_space: int = 0,
16
+ ) -> np.ndarray:
17
+ """
18
+ Computes random mask spans for a given shape
19
+
20
+ Args:
21
+ shape: the the shape for which to compute masks.
22
+ should be of size 2 where first element is batch size and 2nd is timesteps
23
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
24
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
25
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
26
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
27
+ mask_type: how to compute mask lengths
28
+ static = fixed size
29
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
30
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
31
+ poisson = sample from possion distribution with lambda = mask length
32
+ min_masks: minimum number of masked spans
33
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
34
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
35
+ """
36
+
37
+ bsz, all_sz = shape
38
+ mask = np.full((bsz, all_sz), False)
39
+
40
+ # Convert mask_prob to a NumPy array
41
+ mask_prob = np.array(mask_prob)
42
+
43
+ # Calculate all_num_mask for each element in the batch
44
+ all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
45
+
46
+ # Apply the max operation with min_masks for each element
47
+ all_num_mask = np.maximum(min_masks, all_num_mask)
48
+
49
+ mask_idcs = []
50
+ for i in range(bsz):
51
+ if padding_mask is not None:
52
+ sz = all_sz - padding_mask[i].long().sum().item()
53
+ num_mask = int(
54
+ # add a random number for probabilistic rounding
55
+ mask_prob * sz / float(mask_length)
56
+ + np.random.rand()
57
+ )
58
+ num_mask = max(min_masks, num_mask)
59
+ else:
60
+ sz = all_sz
61
+ num_mask = all_num_mask[i]
62
+
63
+ if mask_type == "static":
64
+ lengths = np.full(num_mask, mask_length)
65
+ elif mask_type == "uniform":
66
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
67
+ elif mask_type == "normal":
68
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
69
+ lengths = [max(1, int(round(x))) for x in lengths]
70
+ elif mask_type == "poisson":
71
+ lengths = np.random.poisson(mask_length, size=num_mask)
72
+ lengths = [int(round(x)) for x in lengths]
73
+ else:
74
+ raise Exception("unknown mask selection " + mask_type)
75
+
76
+ if sum(lengths) == 0:
77
+ lengths[0] = min(mask_length, sz - 1)
78
+
79
+ if no_overlap:
80
+ mask_idc = []
81
+
82
+ def arrange(s, e, length, keep_length):
83
+ span_start = np.random.randint(s, e - length)
84
+ mask_idc.extend(span_start + i for i in range(length))
85
+
86
+ new_parts = []
87
+ if span_start - s - min_space >= keep_length:
88
+ new_parts.append((s, span_start - min_space + 1))
89
+ if e - span_start - keep_length - min_space > keep_length:
90
+ new_parts.append((span_start + length + min_space, e))
91
+ return new_parts
92
+
93
+ parts = [(0, sz)]
94
+ min_length = min(lengths)
95
+ for length in sorted(lengths, reverse=True):
96
+ lens = np.fromiter(
97
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
98
+ np.int,
99
+ )
100
+ l_sum = np.sum(lens)
101
+ if l_sum == 0:
102
+ break
103
+ probs = lens / np.sum(lens)
104
+ c = np.random.choice(len(parts), p=probs)
105
+ s, e = parts.pop(c)
106
+ parts.extend(arrange(s, e, length, min_length))
107
+ mask_idc = np.asarray(mask_idc)
108
+ else:
109
+ min_len = min(lengths)
110
+ if sz - min_len <= num_mask:
111
+ min_len = sz - num_mask - 1
112
+
113
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
114
+
115
+ mask_idc = np.asarray(
116
+ [
117
+ mask_idc[j] + offset
118
+ for j in range(len(mask_idc))
119
+ for offset in range(lengths[j])
120
+ ]
121
+ )
122
+
123
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
124
+ # min_len = min([len(m) for m in mask_idcs])
125
+ for i, mask_idc in enumerate(mask_idcs):
126
+ # if len(mask_idc) > min_len:
127
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
128
+ mask[i, mask_idc] = True
129
+
130
+ return torch.tensor(mask)
131
+
132
+
133
+ if __name__ == '__main__':
134
+ mask = compute_mask_indices(
135
+ shape=[4, 500],
136
+ padding_mask=None,
137
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
138
+ mask_length=10,
139
+ mask_type="static",
140
+ mask_other=0.0,
141
+ min_masks=1,
142
+ no_overlap=False,
143
+ min_space=0,
144
+ )
145
+ print(mask)
146
+ print(mask.sum(dim=1))
src/models/utils/timm.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code from timm 0.3.2
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import warnings
6
+
7
+
8
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
10
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11
+ def norm_cdf(x):
12
+ # Computes standard normal cumulative distribution function
13
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
14
+
15
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
16
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17
+ "The distribution of values may be incorrect.",
18
+ stacklevel=2)
19
+
20
+ with torch.no_grad():
21
+ # Values are generated by using a truncated uniform distribution and
22
+ # then using the inverse CDF for the normal distribution.
23
+ # Get upper and lower cdf values
24
+ l = norm_cdf((a - mean) / std)
25
+ u = norm_cdf((b - mean) / std)
26
+
27
+ # Uniformly fill tensor with values from [l, u], then translate to
28
+ # [2l-1, 2u-1].
29
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
30
+
31
+ # Use inverse cdf transform for normal distribution to get truncated
32
+ # standard normal
33
+ tensor.erfinv_()
34
+
35
+ # Transform to proper mean, std
36
+ tensor.mul_(std * math.sqrt(2.))
37
+ tensor.add_(mean)
38
+
39
+ # Clamp to ensure it's in the proper range
40
+ tensor.clamp_(min=a, max=b)
41
+ return tensor
42
+
43
+
44
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45
+ # type: (Tensor, float, float, float, float) -> Tensor
46
+ r"""Fills the input Tensor with values drawn from a truncated
47
+ normal distribution. The values are effectively drawn from the
48
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49
+ with values outside :math:`[a, b]` redrawn until they are within
50
+ the bounds. The method used for generating the random values works
51
+ best when :math:`a \leq \text{mean} \leq b`.
52
+ Args:
53
+ tensor: an n-dimensional `torch.Tensor`
54
+ mean: the mean of the normal distribution
55
+ std: the standard deviation of the normal distribution
56
+ a: the minimum cutoff value
57
+ b: the maximum cutoff value
58
+ Examples:
59
+ >>> w = torch.empty(3, 5)
60
+ >>> nn.init.trunc_normal_(w)
61
+ """
62
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63
+
64
+
65
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
66
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67
+
68
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72
+ 'survival rate' as the argument.
73
+
74
+ """
75
+ if drop_prob == 0. or not training:
76
+ return x
77
+ keep_prob = 1 - drop_prob
78
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80
+ random_tensor.floor_() # binarize
81
+ output = x.div(keep_prob) * random_tensor
82
+ return output
83
+
84
+
85
+ class DropPath(nn.Module):
86
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87
+ """
88
+
89
+ def __init__(self, drop_prob=None):
90
+ super(DropPath, self).__init__()
91
+ self.drop_prob = drop_prob
92
+
93
+ def forward(self, x):
94
+ return drop_path(x, self.drop_prob, self.training)
95
+
96
+
97
+ class Mlp(nn.Module):
98
+ def __init__(self, in_features, hidden_features=None, out_features=None,
99
+ act_layer=nn.GELU, drop=0.):
100
+ super().__init__()
101
+ out_features = out_features or in_features
102
+ hidden_features = hidden_features or in_features
103
+ self.fc1 = nn.Linear(in_features, hidden_features)
104
+ self.act = act_layer()
105
+ self.fc2 = nn.Linear(hidden_features, out_features)
106
+ self.drop = nn.Dropout(drop)
107
+
108
+ def forward(self, x):
109
+ x = self.fc1(x)
110
+ x = self.act(x)
111
+ x = self.drop(x)
112
+ x = self.fc2(x)
113
+ x = self.drop(x)
114
+ return x
src/modules/autoencoder_wrapper.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .dac import DAC
4
+ from .stable_vae import load_vae
5
+
6
+
7
+ class Autoencoder(nn.Module):
8
+ def __init__(self, ckpt_path, model_type='dac', quantization_first=False):
9
+ super(Autoencoder, self).__init__()
10
+ self.model_type = model_type
11
+ if self.model_type == 'dac':
12
+ model = DAC.load(ckpt_path)
13
+ elif self.model_type == 'stable_vae':
14
+ model = load_vae(ckpt_path)
15
+ else:
16
+ raise NotImplementedError(f"Model type not implemented: {self.model_type}")
17
+ self.ae = model.eval()
18
+ self.quantization_first = quantization_first
19
+ print(f'Autoencoder quantization first mode: {quantization_first}')
20
+
21
+ @torch.no_grad()
22
+ def forward(self, audio=None, embedding=None):
23
+ if self.model_type == 'dac':
24
+ return self.process_dac(audio, embedding)
25
+ elif self.model_type == 'encodec':
26
+ return self.process_encodec(audio, embedding)
27
+ elif self.model_type == 'stable_vae':
28
+ return self.process_stable_vae(audio, embedding)
29
+ else:
30
+ raise NotImplementedError(f"Model type not implemented: {self.model_type}")
31
+
32
+ def process_dac(self, audio=None, embedding=None):
33
+ if audio is not None:
34
+ z = self.ae.encoder(audio)
35
+ if self.quantization_first:
36
+ z, *_ = self.ae.quantizer(z, None)
37
+ return z
38
+ elif embedding is not None:
39
+ z = embedding
40
+ if self.quantization_first:
41
+ audio = self.ae.decoder(z)
42
+ else:
43
+ z, *_ = self.ae.quantizer(z, None)
44
+ audio = self.ae.decoder(z)
45
+ return audio
46
+ else:
47
+ raise ValueError("Either audio or embedding must be provided.")
48
+
49
+ def process_encodec(self, audio=None, embedding=None):
50
+ if audio is not None:
51
+ z = self.ae.encoder(audio)
52
+ if self.quantization_first:
53
+ code = self.ae.quantizer.encode(z)
54
+ z = self.ae.quantizer.decode(code)
55
+ return z
56
+ elif embedding is not None:
57
+ z = embedding
58
+ if self.quantization_first:
59
+ audio = self.ae.decoder(z)
60
+ else:
61
+ code = self.ae.quantizer.encode(z)
62
+ z = self.ae.quantizer.decode(code)
63
+ audio = self.ae.decoder(z)
64
+ return audio
65
+ else:
66
+ raise ValueError("Either audio or embedding must be provided.")
67
+
68
+ def process_stable_vae(self, audio=None, embedding=None):
69
+ if audio is not None:
70
+ z = self.ae.encoder(audio)
71
+ if self.quantization_first:
72
+ z = self.ae.bottleneck.encode(z)
73
+ return z
74
+ if embedding is not None:
75
+ z = embedding
76
+ if self.quantization_first:
77
+ audio = self.ae.decoder(z)
78
+ else:
79
+ z = self.ae.bottleneck.encode(z)
80
+ audio = self.ae.decoder(z)
81
+ return audio
82
+ else:
83
+ raise ValueError("Either audio or embedding must be provided.")
src/modules/clap_wrapper.py ADDED
File without changes
src/modules/dac/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
src/modules/dac/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from dac.utils import download
6
+ from dac.utils.decode import decode
7
+ from dac.utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
src/modules/dac/compare/__init__.py ADDED
File without changes
src/modules/dac/compare/encodec.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from audiotools import AudioSignal
3
+ from audiotools.ml import BaseModel
4
+ from encodec import EncodecModel
5
+
6
+
7
+ class Encodec(BaseModel):
8
+ def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
9
+ super().__init__()
10
+
11
+ if sample_rate == 24000:
12
+ self.model = EncodecModel.encodec_model_24khz()
13
+ else:
14
+ self.model = EncodecModel.encodec_model_48khz()
15
+ self.model.set_target_bandwidth(bandwidth)
16
+ self.sample_rate = 44100
17
+
18
+ def forward(
19
+ self,
20
+ audio_data: torch.Tensor,
21
+ sample_rate: int = 44100,
22
+ n_quantizers: int = None,
23
+ ):
24
+ signal = AudioSignal(audio_data, sample_rate)
25
+ signal.resample(self.model.sample_rate)
26
+ recons = self.model(signal.audio_data)
27
+ recons = AudioSignal(recons, self.model.sample_rate)
28
+ recons.resample(sample_rate)
29
+ return {"audio": recons.audio_data}
30
+
31
+
32
+ if __name__ == "__main__":
33
+ import numpy as np
34
+ from functools import partial
35
+
36
+ model = Encodec()
37
+
38
+ for n, m in model.named_modules():
39
+ o = m.extra_repr()
40
+ p = sum([np.prod(p.size()) for p in m.parameters()])
41
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
42
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
43
+ print(model)
44
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
45
+
46
+ length = 88200 * 2
47
+ x = torch.randn(1, 1, length).to(model.device)
48
+ x.requires_grad_(True)
49
+ x.retain_grad()
50
+
51
+ # Make a forward pass
52
+ out = model(x)["audio"]
53
+
54
+ print(x.shape, out.shape)
src/modules/dac/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
src/modules/dac/model/base.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to reconstruct
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 5.0
144
+ verbose : bool, optional
145
+ by default False
146
+ normalize_db : float, optional
147
+ normalize db, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ original_sr = audio_signal.sample_rate
165
+
166
+ resample_fn = audio_signal.resample
167
+ loudness_fn = audio_signal.loudness
168
+
169
+ # If audio is > 10 minutes long, use the ffmpeg versions
170
+ if audio_signal.signal_duration >= 10 * 60 * 60:
171
+ resample_fn = audio_signal.ffmpeg_resample
172
+ loudness_fn = audio_signal.ffmpeg_loudness
173
+
174
+ original_length = audio_signal.signal_length
175
+ resample_fn(self.sample_rate)
176
+ input_db = loudness_fn()
177
+
178
+ if normalize_db is not None:
179
+ audio_signal.normalize(normalize_db)
180
+ audio_signal.ensure_max_of_audio()
181
+
182
+ nb, nac, nt = audio_signal.audio_data.shape
183
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184
+ win_duration = (
185
+ audio_signal.signal_duration if win_duration is None else win_duration
186
+ )
187
+
188
+ if audio_signal.signal_duration <= win_duration:
189
+ # Unchunked compression (used if signal length < win duration)
190
+ self.padding = True
191
+ n_samples = nt
192
+ hop = nt
193
+ else:
194
+ # Chunked inference
195
+ self.padding = False
196
+ # Zero-pad signal on either side by the delay
197
+ audio_signal.zero_pad(self.delay, self.delay)
198
+ n_samples = int(win_duration * self.sample_rate)
199
+ # Round n_samples to nearest hop length multiple
200
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201
+ hop = self.get_output_length(n_samples)
202
+
203
+ codes = []
204
+ range_fn = range if not verbose else tqdm.trange
205
+
206
+ for i in range_fn(0, nt, hop):
207
+ x = audio_signal[..., i : i + n_samples]
208
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209
+
210
+ audio_data = x.audio_data.to(self.device)
211
+ audio_data = self.preprocess(audio_data, self.sample_rate)
212
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213
+ codes.append(c.to(original_device))
214
+ chunk_length = c.shape[-1]
215
+
216
+ codes = torch.cat(codes, dim=-1)
217
+
218
+ dac_file = DACFile(
219
+ codes=codes,
220
+ chunk_length=chunk_length,
221
+ original_length=original_length,
222
+ input_db=input_db,
223
+ channels=nac,
224
+ sample_rate=original_sr,
225
+ padding=self.padding,
226
+ dac_version=SUPPORTED_VERSIONS[-1],
227
+ )
228
+
229
+ if n_quantizers is not None:
230
+ codes = codes[:, :n_quantizers, :]
231
+
232
+ self.padding = original_padding
233
+ return dac_file
234
+
235
+ @torch.no_grad()
236
+ def decompress(
237
+ self,
238
+ obj: Union[str, Path, DACFile],
239
+ verbose: bool = False,
240
+ ) -> AudioSignal:
241
+ """Reconstruct audio from a given .dac file
242
+
243
+ Parameters
244
+ ----------
245
+ obj : Union[str, Path, DACFile]
246
+ .dac file location or corresponding DACFile object.
247
+ verbose : bool, optional
248
+ Prints progress if True, by default False
249
+
250
+ Returns
251
+ -------
252
+ AudioSignal
253
+ Object with the reconstructed audio
254
+ """
255
+ self.eval()
256
+ if isinstance(obj, (str, Path)):
257
+ obj = DACFile.load(obj)
258
+
259
+ original_padding = self.padding
260
+ self.padding = obj.padding
261
+
262
+ range_fn = range if not verbose else tqdm.trange
263
+ codes = obj.codes
264
+ original_device = codes.device
265
+ chunk_length = obj.chunk_length
266
+ recons = []
267
+
268
+ for i in range_fn(0, codes.shape[-1], chunk_length):
269
+ c = codes[..., i : i + chunk_length].to(self.device)
270
+ z = self.quantizer.from_codes(c)[0]
271
+ r = self.decode(z)
272
+ recons.append(r.to(original_device))
273
+
274
+ recons = torch.cat(recons, dim=-1)
275
+ recons = AudioSignal(recons, self.sample_rate)
276
+
277
+ resample_fn = recons.resample
278
+ loudness_fn = recons.loudness
279
+
280
+ # If audio is > 10 minutes long, use the ffmpeg versions
281
+ if recons.signal_duration >= 10 * 60 * 60:
282
+ resample_fn = recons.ffmpeg_resample
283
+ loudness_fn = recons.ffmpeg_loudness
284
+
285
+ recons.normalize(obj.input_db)
286
+ resample_fn(obj.sample_rate)
287
+ recons = recons[..., : obj.original_length]
288
+ loudness_fn()
289
+ recons.audio_data = recons.audio_data.reshape(
290
+ -1, obj.channels, obj.original_length
291
+ )
292
+
293
+ self.padding = original_padding
294
+ return recons
src/modules/dac/model/dac.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from ..nn.layers import Snake1d
13
+ from ..nn.layers import WNConv1d
14
+ from ..nn.layers import WNConvTranspose1d
15
+ from ..nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 64,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap black into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ ),
106
+ ResidualUnit(output_dim, dilation=1),
107
+ ResidualUnit(output_dim, dilation=3),
108
+ ResidualUnit(output_dim, dilation=9),
109
+ )
110
+
111
+ def forward(self, x):
112
+ return self.block(x)
113
+
114
+
115
+ class Decoder(nn.Module):
116
+ def __init__(
117
+ self,
118
+ input_channel,
119
+ channels,
120
+ rates,
121
+ d_out: int = 1,
122
+ ):
123
+ super().__init__()
124
+
125
+ # Add first conv layer
126
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
127
+
128
+ # Add upsampling + MRF blocks
129
+ for i, stride in enumerate(rates):
130
+ input_dim = channels // 2**i
131
+ output_dim = channels // 2 ** (i + 1)
132
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
133
+
134
+ # Add final conv layer
135
+ layers += [
136
+ Snake1d(output_dim),
137
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
138
+ nn.Tanh(),
139
+ ]
140
+
141
+ self.model = nn.Sequential(*layers)
142
+
143
+ def forward(self, x):
144
+ return self.model(x)
145
+
146
+
147
+ class DAC(BaseModel, CodecMixin):
148
+ def __init__(
149
+ self,
150
+ encoder_dim: int = 64,
151
+ encoder_rates: List[int] = [2, 4, 8, 8],
152
+ latent_dim: int = None,
153
+ decoder_dim: int = 1536,
154
+ decoder_rates: List[int] = [8, 8, 4, 2],
155
+ n_codebooks: int = 9,
156
+ codebook_size: int = 1024,
157
+ codebook_dim: Union[int, list] = 8,
158
+ quantizer_dropout: bool = False,
159
+ sample_rate: int = 44100,
160
+ ):
161
+ super().__init__()
162
+
163
+ self.encoder_dim = encoder_dim
164
+ self.encoder_rates = encoder_rates
165
+ self.decoder_dim = decoder_dim
166
+ self.decoder_rates = decoder_rates
167
+ self.sample_rate = sample_rate
168
+
169
+ if latent_dim is None:
170
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
171
+
172
+ self.latent_dim = latent_dim
173
+
174
+ self.hop_length = np.prod(encoder_rates)
175
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
176
+
177
+ self.n_codebooks = n_codebooks
178
+ self.codebook_size = codebook_size
179
+ self.codebook_dim = codebook_dim
180
+ self.quantizer = ResidualVectorQuantize(
181
+ input_dim=latent_dim,
182
+ n_codebooks=n_codebooks,
183
+ codebook_size=codebook_size,
184
+ codebook_dim=codebook_dim,
185
+ quantizer_dropout=quantizer_dropout,
186
+ )
187
+
188
+ self.decoder = Decoder(
189
+ latent_dim,
190
+ decoder_dim,
191
+ decoder_rates,
192
+ )
193
+ self.sample_rate = sample_rate
194
+ self.apply(init_weights)
195
+
196
+ self.delay = self.get_delay()
197
+
198
+ def preprocess(self, audio_data, sample_rate):
199
+ if sample_rate is None:
200
+ sample_rate = self.sample_rate
201
+ assert sample_rate == self.sample_rate
202
+
203
+ length = audio_data.shape[-1]
204
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
205
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
206
+
207
+ return audio_data
208
+
209
+ def encode(
210
+ self,
211
+ audio_data: torch.Tensor,
212
+ n_quantizers: int = None,
213
+ ):
214
+ """Encode given audio data and return quantized latent codes
215
+
216
+ Parameters
217
+ ----------
218
+ audio_data : Tensor[B x 1 x T]
219
+ Audio data to encode
220
+ n_quantizers : int, optional
221
+ Number of quantizers to use, by default None
222
+ If None, all quantizers are used.
223
+
224
+ Returns
225
+ -------
226
+ dict
227
+ A dictionary with the following keys:
228
+ "z" : Tensor[B x D x T]
229
+ Quantized continuous representation of input
230
+ "codes" : Tensor[B x N x T]
231
+ Codebook indices for each codebook
232
+ (quantized discrete representation of input)
233
+ "latents" : Tensor[B x N*D x T]
234
+ Projected latents (continuous representation of input before quantization)
235
+ "vq/commitment_loss" : Tensor[1]
236
+ Commitment loss to train encoder to predict vectors closer to codebook
237
+ entries
238
+ "vq/codebook_loss" : Tensor[1]
239
+ Codebook loss to update the codebook
240
+ "length" : int
241
+ Number of samples in input audio
242
+ """
243
+ z = self.encoder(audio_data)
244
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
245
+ z, n_quantizers
246
+ )
247
+ return z, codes, latents, commitment_loss, codebook_loss
248
+
249
+ def decode(self, z: torch.Tensor):
250
+ """Decode given latent codes and return audio data
251
+
252
+ Parameters
253
+ ----------
254
+ z : Tensor[B x D x T]
255
+ Quantized continuous representation of input
256
+ length : int, optional
257
+ Number of samples in output audio, by default None
258
+
259
+ Returns
260
+ -------
261
+ dict
262
+ A dictionary with the following keys:
263
+ "audio" : Tensor[B x 1 x length]
264
+ Decoded audio data.
265
+ """
266
+ return self.decoder(z)
267
+
268
+ def forward(
269
+ self,
270
+ audio_data: torch.Tensor,
271
+ sample_rate: int = None,
272
+ n_quantizers: int = None,
273
+ ):
274
+ """Model forward pass
275
+
276
+ Parameters
277
+ ----------
278
+ audio_data : Tensor[B x 1 x T]
279
+ Audio data to encode
280
+ sample_rate : int, optional
281
+ Sample rate of audio data in Hz, by default None
282
+ If None, defaults to `self.sample_rate`
283
+ n_quantizers : int, optional
284
+ Number of quantizers to use, by default None.
285
+ If None, all quantizers are used.
286
+
287
+ Returns
288
+ -------
289
+ dict
290
+ A dictionary with the following keys:
291
+ "z" : Tensor[B x D x T]
292
+ Quantized continuous representation of input
293
+ "codes" : Tensor[B x N x T]
294
+ Codebook indices for each codebook
295
+ (quantized discrete representation of input)
296
+ "latents" : Tensor[B x N*D x T]
297
+ Projected latents (continuous representation of input before quantization)
298
+ "vq/commitment_loss" : Tensor[1]
299
+ Commitment loss to train encoder to predict vectors closer to codebook
300
+ entries
301
+ "vq/codebook_loss" : Tensor[1]
302
+ Codebook loss to update the codebook
303
+ "length" : int
304
+ Number of samples in input audio
305
+ "audio" : Tensor[B x 1 x length]
306
+ Decoded audio data.
307
+ """
308
+ length = audio_data.shape[-1]
309
+ audio_data = self.preprocess(audio_data, sample_rate)
310
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
311
+ audio_data, n_quantizers
312
+ )
313
+
314
+ x = self.decode(z)
315
+ return {
316
+ "audio": x[..., :length],
317
+ "z": z,
318
+ "codes": codes,
319
+ "latents": latents,
320
+ "vq/commitment_loss": commitment_loss,
321
+ "vq/codebook_loss": codebook_loss,
322
+ }
323
+
324
+
325
+ if __name__ == "__main__":
326
+ import numpy as np
327
+ from functools import partial
328
+
329
+ model = DAC().to("cpu")
330
+
331
+ for n, m in model.named_modules():
332
+ o = m.extra_repr()
333
+ p = sum([np.prod(p.size()) for p in m.parameters()])
334
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
335
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
336
+ print(model)
337
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
338
+
339
+ length = 88200 * 2
340
+ x = torch.randn(1, 1, length).to(model.device)
341
+ x.requires_grad_(True)
342
+ x.retain_grad()
343
+
344
+ # Make a forward pass
345
+ out = model(x)["audio"]
346
+ print("Input shape:", x.shape)
347
+ print("Output shape:", out.shape)
348
+
349
+ # Create gradient variable
350
+ grad = torch.zeros_like(out)
351
+ grad[:, :, grad.shape[-1] // 2] = 1
352
+
353
+ # Make a backward pass
354
+ out.backward(grad)
355
+
356
+ # Check non-zero values
357
+ gradmap = x.grad.squeeze(0)
358
+ gradmap = (gradmap != 0).sum(0) # sum across features
359
+ rf = (gradmap != 0).sum()
360
+
361
+ print(f"Receptive field: {rf.item()}")
362
+
363
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
364
+ model.decompress(model.compress(x, verbose=True), verbose=True)