reach-vb HF staff adefossez commited on
Commit
5325fcc
·
1 Parent(s): ed87f04

Stereo demo update (#60)

Browse files

- updated demo (a16e65ea0e58cfa993cbd11f06f2ded8eafb559c)


Co-authored-by: Alexandre Défossez <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/actions/audiocraft_build/action.yml +2 -0
  2. .github/workflows/audiocraft_docs.yml +3 -3
  3. .github/workflows/audiocraft_tests.yml +6 -1
  4. .gitignore +8 -1
  5. CHANGELOG.md +31 -1
  6. CONTRIBUTING.md +2 -2
  7. LICENSE_weights +399 -157
  8. MANIFEST.in +7 -0
  9. Makefile +23 -4
  10. README.md +43 -83
  11. assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
  12. assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
  13. audiocraft/__init__.py +17 -1
  14. audiocraft/adversarial/__init__.py +22 -0
  15. audiocraft/adversarial/discriminators/__init__.py +10 -0
  16. audiocraft/adversarial/discriminators/base.py +34 -0
  17. audiocraft/adversarial/discriminators/mpd.py +106 -0
  18. audiocraft/adversarial/discriminators/msd.py +126 -0
  19. audiocraft/adversarial/discriminators/msstftd.py +134 -0
  20. audiocraft/adversarial/losses.py +228 -0
  21. audiocraft/data/__init__.py +3 -1
  22. audiocraft/data/audio.py +37 -21
  23. audiocraft/data/audio_dataset.py +93 -31
  24. audiocraft/data/audio_utils.py +12 -10
  25. audiocraft/data/info_audio_dataset.py +110 -0
  26. audiocraft/data/music_dataset.py +270 -0
  27. audiocraft/data/sound_dataset.py +330 -0
  28. audiocraft/data/zip.py +8 -6
  29. audiocraft/environment.py +176 -0
  30. audiocraft/grids/__init__.py +6 -0
  31. audiocraft/grids/_base_explorers.py +80 -0
  32. audiocraft/grids/audiogen/__init__.py +6 -0
  33. audiocraft/grids/audiogen/audiogen_base_16khz.py +23 -0
  34. audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +68 -0
  35. audiocraft/grids/compression/__init__.py +6 -0
  36. audiocraft/grids/compression/_explorers.py +55 -0
  37. audiocraft/grids/compression/debug.py +31 -0
  38. audiocraft/grids/compression/encodec_audiogen_16khz.py +29 -0
  39. audiocraft/grids/compression/encodec_base_24khz.py +28 -0
  40. audiocraft/grids/compression/encodec_musicgen_32khz.py +34 -0
  41. audiocraft/grids/diffusion/4_bands_base_32khz.py +27 -0
  42. audiocraft/grids/diffusion/__init__.py +6 -0
  43. audiocraft/grids/diffusion/_explorers.py +66 -0
  44. audiocraft/grids/musicgen/__init__.py +6 -0
  45. audiocraft/grids/musicgen/_explorers.py +93 -0
  46. audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
  47. audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +67 -0
  48. audiocraft/grids/musicgen/musicgen_clapemb_32khz.py +32 -0
  49. audiocraft/grids/musicgen/musicgen_melody_32khz.py +65 -0
  50. audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py +99 -0
.github/actions/audiocraft_build/action.yml CHANGED
@@ -21,6 +21,8 @@ runs:
21
  python3 -m venv env
22
  . env/bin/activate
23
  python -m pip install --upgrade pip
 
 
24
  pip install -e '.[dev]'
25
  - name: System Dependencies
26
  shell: bash
 
21
  python3 -m venv env
22
  . env/bin/activate
23
  python -m pip install --upgrade pip
24
+ pip install torch torchvision torchaudio
25
+ pip install xformers
26
  pip install -e '.[dev]'
27
  - name: System Dependencies
28
  shell: bash
.github/workflows/audiocraft_docs.yml CHANGED
@@ -23,9 +23,9 @@ jobs:
23
  - name: Make docs
24
  run: |
25
  . env/bin/activate
26
- make docs
27
- git add -f docs
28
- git commit -m docs
29
 
30
  - name: Push branch
31
  run: |
 
23
  - name: Make docs
24
  run: |
25
  . env/bin/activate
26
+ make api_docs
27
+ git add -f api_docs
28
+ git commit -m api_docs
29
 
30
  - name: Push branch
31
  run: |
.github/workflows/audiocraft_tests.yml CHANGED
@@ -12,6 +12,11 @@ jobs:
12
  steps:
13
  - uses: actions/checkout@v2
14
  - uses: ./.github/actions/audiocraft_build
15
- - run: |
 
16
  . env/bin/activate
17
  make tests
 
 
 
 
 
12
  steps:
13
  - uses: actions/checkout@v2
14
  - uses: ./.github/actions/audiocraft_build
15
+ - name: Run unit tests
16
+ run: |
17
  . env/bin/activate
18
  make tests
19
+ - name: Run integration tests
20
+ run: |
21
+ . env/bin/activate
22
+ make tests_integ
.gitignore CHANGED
@@ -35,7 +35,7 @@ wheels/
35
  .coverage
36
 
37
  # docs
38
- /docs
39
 
40
  # dotenv
41
  .env
@@ -46,6 +46,13 @@ wheels/
46
  venv/
47
  ENV/
48
 
 
 
 
 
 
 
 
49
  # personal notebooks & scripts
50
  */local_scripts
51
  */notes
 
35
  .coverage
36
 
37
  # docs
38
+ /api_docs
39
 
40
  # dotenv
41
  .env
 
46
  venv/
47
  ENV/
48
 
49
+ # egs with manifest files
50
+ egs/*
51
+ !egs/example
52
+ # local datasets
53
+ dataset/*
54
+ !dataset/example
55
+
56
  # personal notebooks & scripts
57
  */local_scripts
58
  */notes
CHANGELOG.md CHANGED
@@ -4,7 +4,37 @@ All notable changes to this project will be documented in this file.
4
 
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6
 
7
- ## [0.0.2a] - TBD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  Improved demo, fixed top p (thanks @jnordberg).
10
 
 
4
 
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6
 
7
+ ## [1.2.0a] - TBD
8
+
9
+ Adding stereo models.
10
+
11
+
12
+ ## [1.1.0] - 2023-11-06
13
+
14
+ Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
15
+
16
+ Fixed DAC support with non default number of codebooks.
17
+
18
+ Fixed bug when `two_step_cfg` was overriden when calling `generate()`.
19
+
20
+ Fixed samples being always prompted with audio, rather than having both prompted and unprompted.
21
+
22
+ **Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release.
23
+ The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners.
24
+ We removed it, so you might need to retrain models.
25
+
26
+ **Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before).
27
+
28
+ **Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
29
+ retrained a model with this pattern, so hopefully this won't impact you!
30
+
31
+
32
+ ## [1.0.0] - 2023-09-07
33
+
34
+ Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
35
+ Added pretrained model for AudioGen and MultiBandDiffusion.
36
+
37
+ ## [0.0.2] - 2023-08-01
38
 
39
  Improved demo, fixed top p (thanks @jnordberg).
40
 
CONTRIBUTING.md CHANGED
@@ -1,11 +1,11 @@
1
- # Contributing to Audiocraft
2
 
3
  We want to make contributing to this project as easy and transparent as
4
  possible.
5
 
6
  ## Pull Requests
7
 
8
- Audiocraft is the implementation of a research paper.
9
  Therefore, we do not plan on accepting many pull requests for new features.
10
  We certainly welcome them for bug fixes.
11
 
 
1
+ # Contributing to AudioCraft
2
 
3
  We want to make contributing to this project as easy and transparent as
4
  possible.
5
 
6
  ## Pull Requests
7
 
8
+ AudioCraft is the implementation of a research paper.
9
  Therefore, we do not plan on accepting many pull requests for new features.
10
  We certainly welcome them for bug fixes.
11
 
LICENSE_weights CHANGED
@@ -1,157 +1,399 @@
1
- # Attribution-NonCommercial-NoDerivatives 4.0 International
2
-
3
- > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
4
- >
5
- > ### Using Creative Commons Public Licenses
6
- >
7
- > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8
- >
9
- > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10
- >
11
- > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12
-
13
- ## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License
14
-
15
- By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16
-
17
- ### Section 1 Definitions.
18
-
19
- a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20
-
21
- b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
22
-
23
- e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
24
-
25
- f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
26
-
27
- h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
28
-
29
- i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
30
-
31
- h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
32
-
33
- i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
34
-
35
- j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
36
-
37
- k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
38
-
39
- l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
40
-
41
- ### Section 2 Scope.
42
-
43
- a. ___License grant.___
44
-
45
- 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
46
-
47
- A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
48
-
49
- B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only.
50
-
51
- 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
52
-
53
- 3. __Term.__ The term of this Public License is specified in Section 6(a).
54
-
55
- 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
56
-
57
- 5. __Downstream recipients.__
58
-
59
- A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
60
-
61
- B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
62
-
63
- 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
64
-
65
- b. ___Other rights.___
66
-
67
- 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
68
-
69
- 2. Patent and trademark rights are not licensed under this Public License.
70
-
71
- 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
72
-
73
- ### Section 3 License Conditions.
74
-
75
- Your exercise of the Licensed Rights is expressly made subject to the following conditions.
76
-
77
- a. ___Attribution.___
78
-
79
- 1. If You Share the Licensed Material, You must:
80
-
81
- A. retain the following if it is supplied by the Licensor with the Licensed Material:
82
-
83
- i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
84
-
85
- ii. a copyright notice;
86
-
87
- iii. a notice that refers to this Public License;
88
-
89
- iv. a notice that refers to the disclaimer of warranties;
90
-
91
- v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
92
-
93
- B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
94
-
95
- C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
96
-
97
- For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material.
98
-
99
- 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
100
-
101
- 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
102
-
103
- ### Section 4 Sui Generis Database Rights.
104
-
105
- Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
106
-
107
- a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material;
108
-
109
- b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
110
-
111
- c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
112
-
113
- For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
114
-
115
- ### Section 5 Disclaimer of Warranties and Limitation of Liability.
116
-
117
- a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
118
-
119
- b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
120
-
121
- c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
122
-
123
- ### Section 6 Term and Termination.
124
-
125
- a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
126
-
127
- b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
128
-
129
- 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
130
-
131
- 2. upon express reinstatement by the Licensor.
132
-
133
- For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
134
-
135
- c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
136
-
137
- d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
138
-
139
- ### Section 7 – Other Terms and Conditions.
140
-
141
- a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
142
-
143
- b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
144
-
145
- ### Section 8 Interpretation.
146
-
147
- a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
148
-
149
- b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
150
-
151
- c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
152
-
153
- d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
154
-
155
- > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
156
- >
157
- > Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
MANIFEST.in CHANGED
@@ -6,3 +6,10 @@ include *.ini
6
  include requirements.txt
7
  include audiocraft/py.typed
8
  include assets/*.mp3
 
 
 
 
 
 
 
 
6
  include requirements.txt
7
  include audiocraft/py.typed
8
  include assets/*.mp3
9
+ include datasets/*.mp3
10
+ recursive-include config *.yaml
11
+ recursive-include demos *.py
12
+ recursive-include demos *.ipynb
13
+ recursive-include scripts *.py
14
+ recursive-include model_cards *.md
15
+ recursive-include docs *.md
Makefile CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  default: linter tests
2
 
3
  install:
@@ -10,12 +22,19 @@ linter:
10
 
11
  tests:
12
  coverage run -m pytest tests
13
- coverage report --include 'audiocraft/*'
 
 
 
 
 
 
 
14
 
15
- docs:
16
- pdoc3 --html -o docs -f audiocraft
17
 
18
  dist:
19
  python setup.py sdist
20
 
21
- .PHONY: linter tests docs dist
 
1
+ INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
2
+ dataset.train.num_samples=10 dataset.valid.num_samples=10 \
3
+ dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
4
+ logging.level=DEBUG
5
+ INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e
6
+ INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
7
+ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
8
+ INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
9
+ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
10
+ INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
11
+ checkpoint.save_last=false # Using compression model from 616d7b3c
12
+
13
  default: linter tests
14
 
15
  install:
 
22
 
23
  tests:
24
  coverage run -m pytest tests
25
+ coverage report
26
+
27
+ tests_integ:
28
+ $(INTEG_COMPRESSION)
29
+ $(INTEG_MBD)
30
+ $(INTEG_MUSICGEN)
31
+ $(INTEG_AUDIOGEN)
32
+
33
 
34
+ api_docs:
35
+ pdoc3 --html -o api_docs -f audiocraft
36
 
37
  dist:
38
  python setup.py sdist
39
 
40
+ .PHONY: linter tests api_docs dist
README.md CHANGED
@@ -5,7 +5,7 @@ tags:
5
  - "music generation"
6
  - "language models"
7
  - "LLMs"
8
- app_file: "app.py"
9
  emoji: 🎵
10
  colorFrom: gray
11
  colorTo: blue
@@ -14,33 +14,17 @@ sdk_version: 3.34.0
14
  pinned: true
15
  license: "cc-by-nc-4.0"
16
  ---
17
- # Audiocraft
18
  ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
19
  ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
20
  ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
21
 
22
- Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model.
 
23
 
24
- ## MusicGen
25
-
26
- Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
27
- Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates
28
- all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
29
- them in parallel, thus having only 50 auto-regressive steps per second of audio.
30
- Check out our [sample page][musicgen_samples] or test the available demo!
31
-
32
- <a target="_blank" href="https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing">
33
- <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
34
- </a>
35
- <a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
36
- <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HugginFace"/>
37
- </a>
38
- <br>
39
-
40
- We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
41
 
42
  ## Installation
43
- Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
44
 
45
  ```shell
46
  # Best to make sure you have torch installed first, in particular before installing xformers.
@@ -49,92 +33,68 @@ pip install 'torch>=2.0'
49
  # Then proceed to one of the following
50
  pip install -U audiocraft # stable release
51
  pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
52
- pip install -e . # or if you cloned the repo locally
53
  ```
54
 
55
- ## Usage
56
- We offer a number of way to interact with MusicGen:
57
- 1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
58
- 2. You can run the Gradio demo in Colab: [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
59
- 3. You can use the gradio demo locally by running `python app.py`.
60
- 4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
61
- 5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) which is regularly
62
- updated with contributions from @camenduru and the community.
63
-
64
- ## API
65
-
66
- We provide a simple API and 4 pre-trained models. The pre trained models are:
67
- - `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
68
- - `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
69
- - `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
70
- - `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
71
-
72
- We observe the best trade-off between quality and compute with the `medium` or `melody` model.
73
- In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
74
- GPUs will be able to generate short sequences, or longer sequences with the `small` model.
75
-
76
- **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
77
- You can install it with:
78
- ```
79
- apt-get install ffmpeg
80
  ```
81
 
82
- See after a quick example for using the API.
83
 
84
- ```python
85
- import torchaudio
86
- from audiocraft.models import MusicGen
87
- from audiocraft.data.audio import audio_write
 
88
 
89
- model = MusicGen.get_pretrained('melody')
90
- model.set_generation_params(duration=8) # generate 8 seconds.
91
- wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
92
- descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
93
- wav = model.generate(descriptions) # generates 3 samples.
94
 
95
- melody, sr = torchaudio.load('./assets/bach.mp3')
96
- # generates using the melody from the given audio and the provided descriptions.
97
- wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
98
 
99
- for idx, one_wav in enumerate(wav):
100
- # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
101
- audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
102
- ```
103
 
104
 
105
- ## Model Card
106
 
107
- See [the model card page](./MODEL_CARD.md).
108
 
109
- ## FAQ
110
 
111
- #### Will the training code be released?
112
 
113
- Yes. We will soon release the training code for MusicGen and EnCodec.
114
 
 
115
 
116
- #### I need help on Windows
117
 
118
- @FurkanGozukara made a complete tutorial for [Audiocraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4)
 
 
119
 
120
- #### I need help for running the demo on Colab
121
 
122
- Check [@camenduru tutorial on Youtube](https://www.youtube.com/watch?v=EGfxuTy9Eeo).
 
 
123
 
124
 
125
  ## Citation
 
 
126
  ```
127
  @article{copet2023simple,
128
- title={Simple and Controllable Music Generation},
129
- author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
130
- year={2023},
131
- journal={arXiv preprint arXiv:2306.05284},
132
  }
133
  ```
134
 
135
- ## License
136
- * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
137
- * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
138
-
139
- [arxiv]: https://arxiv.org/abs/2306.05284
140
- [musicgen_samples]: https://ai.honu.io/papers/musicgen/
 
5
  - "music generation"
6
  - "language models"
7
  - "LLMs"
8
+ app_file: "demos/musicgen_app.py"
9
  emoji: 🎵
10
  colorFrom: gray
11
  colorTo: blue
 
14
  pinned: true
15
  license: "cc-by-nc-4.0"
16
  ---
17
+ # AudioCraft
18
  ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
19
  ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
20
  ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
21
 
22
+ AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
23
+ for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  ## Installation
27
+ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run the following:
28
 
29
  ```shell
30
  # Best to make sure you have torch installed first, in particular before installing xformers.
 
33
  # Then proceed to one of the following
34
  pip install -U audiocraft # stable release
35
  pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
36
+ pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
37
  ```
38
 
39
+ We also recommend having `ffmpeg` installed, either through your system or Anaconda:
40
+ ```bash
41
+ sudo apt-get install ffmpeg
42
+ # Or if you are using Anaconda or Miniconda
43
+ conda install "ffmpeg<5" -c conda-forge
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ```
45
 
46
+ ## Models
47
 
48
+ At the moment, AudioCraft contains the training code and inference code for:
49
+ * [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
50
+ * [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
51
+ * [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
52
+ * [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
53
 
54
+ ## Training code
 
 
 
 
55
 
56
+ AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
57
+ For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
58
+ the [AudioCraft training documentation](./docs/TRAINING.md).
59
 
60
+ For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model
61
+ that provides pointers to configuration, example grids and model/task-specific information and FAQ.
 
 
62
 
63
 
64
+ ## API documentation
65
 
66
+ We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.
67
 
 
68
 
69
+ ## FAQ
70
 
71
+ #### Is the training code available?
72
 
73
+ Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md).
74
 
75
+ #### Where are the models stored?
76
 
77
+ Hugging Face stored the model in a specific location, which can be overriden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable for the AudioCraft models.
78
+ In order to change the cache location of the other Hugging Face models, please check out the [Hugging Face Transformers documentation for the cache setup](https://huggingface.co/docs/transformers/installation#cache-setup).
79
+ Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and want to change the download location for Demucs, refer to the [Torch Hub documentation](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved).
80
 
 
81
 
82
+ ## License
83
+ * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
84
+ * The models weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
85
 
86
 
87
  ## Citation
88
+
89
+ For the general framework of AudioCraft, please cite the following.
90
  ```
91
  @article{copet2023simple,
92
+ title={Simple and Controllable Music Generation},
93
+ author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
94
+ year={2023},
95
+ journal={arXiv preprint arXiv:2306.05284},
96
  }
97
  ```
98
 
99
+ When referring to a specific model, please cite as mentioned in the model specific README, e.g
100
+ [./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
 
 
 
 
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 ADDED
Binary file (15.2 kB). View file
 
assets/sirens_and_a_humming_engine_approach_and_pass.mp3 ADDED
Binary file (15.2 kB). View file
 
audiocraft/__init__.py CHANGED
@@ -3,8 +3,24 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # flake8: noqa
8
  from . import data, modules, models
9
 
10
- __version__ = '0.0.2a2'
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """
7
+ AudioCraft is a general framework for training audio generative models.
8
+ At the moment we provide the training code for:
9
+
10
+ - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
11
+ text-to-music and melody+text autoregressive generative model.
12
+ For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
13
+ `audiocraft.models.musicgen.MusicGen`.
14
+ - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
15
+ text-to-general-audio generative model.
16
+ - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
17
+ neural audio codec which provides an excellent tokenizer for autoregressive language models.
18
+ See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
19
+ - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
20
+ improves the perceived quality and reduces the artifacts coming from adversarial decoders.
21
+ """
22
 
23
  # flake8: noqa
24
  from . import data, modules, models
25
 
26
+ __version__ = '1.1.0'
audiocraft/adversarial/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Adversarial losses and discriminator architectures."""
7
+
8
+ # flake8: noqa
9
+ from .discriminators import (
10
+ MultiPeriodDiscriminator,
11
+ MultiScaleDiscriminator,
12
+ MultiScaleSTFTDiscriminator
13
+ )
14
+ from .losses import (
15
+ AdversarialLoss,
16
+ AdvLossType,
17
+ get_adv_criterion,
18
+ get_fake_criterion,
19
+ get_real_criterion,
20
+ FeatLossType,
21
+ FeatureMatchingLoss
22
+ )
audiocraft/adversarial/discriminators/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .mpd import MultiPeriodDiscriminator
9
+ from .msd import MultiScaleDiscriminator
10
+ from .msstftd import MultiScaleSTFTDiscriminator
audiocraft/adversarial/discriminators/base.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import typing as tp
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ FeatureMapType = tp.List[torch.Tensor]
15
+ LogitsType = torch.Tensor
16
+ MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
17
+
18
+
19
+ class MultiDiscriminator(ABC, nn.Module):
20
+ """Base implementation for discriminators composed of sub-discriminators acting at different scales.
21
+ """
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @abstractmethod
26
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
27
+ ...
28
+
29
+ @property
30
+ @abstractmethod
31
+ def num_discriminators(self) -> int:
32
+ """Number of discriminators.
33
+ """
34
+ ...
audiocraft/adversarial/discriminators/mpd.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...modules import NormConv2d
14
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
+
16
+
17
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
18
+ return int((kernel_size * dilation - dilation) / 2)
19
+
20
+
21
+ class PeriodDiscriminator(nn.Module):
22
+ """Period sub-discriminator.
23
+
24
+ Args:
25
+ period (int): Period between samples of audio.
26
+ in_channels (int): Number of input channels.
27
+ out_channels (int): Number of output channels.
28
+ n_layers (int): Number of convolutional layers.
29
+ kernel_sizes (list of int): Kernel sizes for convolutions.
30
+ stride (int): Stride for convolutions.
31
+ filters (int): Initial number of filters in convolutions.
32
+ filters_scale (int): Multiplier of number of filters as we increase depth.
33
+ max_filters (int): Maximum number of filters.
34
+ norm (str): Normalization method.
35
+ activation (str): Activation function.
36
+ activation_params (dict): Parameters to provide to the activation function.
37
+ """
38
+ def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
39
+ n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
40
+ filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
41
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
+ activation_params: dict = {'negative_slope': 0.2}):
43
+ super().__init__()
44
+ self.period = period
45
+ self.n_layers = n_layers
46
+ self.activation = getattr(torch.nn, activation)(**activation_params)
47
+ self.convs = nn.ModuleList()
48
+ in_chs = in_channels
49
+ for i in range(self.n_layers):
50
+ out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
51
+ eff_stride = 1 if i == self.n_layers - 1 else stride
52
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
53
+ padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
54
+ in_chs = out_chs
55
+ self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
56
+ padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
57
+
58
+ def forward(self, x: torch.Tensor):
59
+ fmap = []
60
+ # 1d to 2d
61
+ b, c, t = x.shape
62
+ if t % self.period != 0: # pad first
63
+ n_pad = self.period - (t % self.period)
64
+ x = F.pad(x, (0, n_pad), 'reflect')
65
+ t = t + n_pad
66
+ x = x.view(b, c, t // self.period, self.period)
67
+
68
+ for conv in self.convs:
69
+ x = conv(x)
70
+ x = self.activation(x)
71
+ fmap.append(x)
72
+ x = self.conv_post(x)
73
+ fmap.append(x)
74
+ # x = torch.flatten(x, 1, -1)
75
+
76
+ return x, fmap
77
+
78
+
79
+ class MultiPeriodDiscriminator(MultiDiscriminator):
80
+ """Multi-Period (MPD) Discriminator.
81
+
82
+ Args:
83
+ in_channels (int): Number of input channels.
84
+ out_channels (int): Number of output channels.
85
+ periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
86
+ **kwargs: Additional args for `PeriodDiscriminator`
87
+ """
88
+ def __init__(self, in_channels: int = 1, out_channels: int = 1,
89
+ periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
90
+ super().__init__()
91
+ self.discriminators = nn.ModuleList([
92
+ PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
93
+ ])
94
+
95
+ @property
96
+ def num_discriminators(self):
97
+ return len(self.discriminators)
98
+
99
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
100
+ logits = []
101
+ fmaps = []
102
+ for disc in self.discriminators:
103
+ logit, fmap = disc(x)
104
+ logits.append(logit)
105
+ fmaps.append(fmap)
106
+ return logits, fmaps
audiocraft/adversarial/discriminators/msd.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from ...modules import NormConv1d
14
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15
+
16
+
17
+ class ScaleDiscriminator(nn.Module):
18
+ """Waveform sub-discriminator.
19
+
20
+ Args:
21
+ in_channels (int): Number of input channels.
22
+ out_channels (int): Number of output channels.
23
+ kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
24
+ filters (int): Number of initial filters for convolutions.
25
+ max_filters (int): Maximum number of filters.
26
+ downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
27
+ inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
28
+ groups (Sequence[int] or None): Groups for inner convolutions.
29
+ strides (Sequence[int] or None): Strides for inner convolutions.
30
+ paddings (Sequence[int] or None): Paddings for inner convolutions.
31
+ norm (str): Normalization method.
32
+ activation (str): Activation function.
33
+ activation_params (dict): Parameters to provide to the activation function.
34
+ pad (str): Padding for initial convolution.
35
+ pad_params (dict): Parameters to provide to the padding module.
36
+ """
37
+ def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
38
+ filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
39
+ inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
40
+ strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
41
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42
+ activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
43
+ pad_params: dict = {}):
44
+ super().__init__()
45
+ assert len(kernel_sizes) == 2
46
+ assert kernel_sizes[0] % 2 == 1
47
+ assert kernel_sizes[1] % 2 == 1
48
+ assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
49
+ assert (groups is None or len(groups) == len(downsample_scales))
50
+ assert (strides is None or len(strides) == len(downsample_scales))
51
+ assert (paddings is None or len(paddings) == len(downsample_scales))
52
+ self.activation = getattr(torch.nn, activation)(**activation_params)
53
+ self.convs = nn.ModuleList()
54
+ self.convs.append(
55
+ nn.Sequential(
56
+ getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
57
+ NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
58
+ )
59
+ )
60
+
61
+ in_chs = filters
62
+ for i, downsample_scale in enumerate(downsample_scales):
63
+ out_chs = min(in_chs * downsample_scale, max_filters)
64
+ default_kernel_size = downsample_scale * 10 + 1
65
+ default_stride = downsample_scale
66
+ default_padding = (default_kernel_size - 1) // 2
67
+ default_groups = in_chs // 4
68
+ self.convs.append(
69
+ NormConv1d(in_chs, out_chs,
70
+ kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
71
+ stride=strides[i] if strides else default_stride,
72
+ groups=groups[i] if groups else default_groups,
73
+ padding=paddings[i] if paddings else default_padding,
74
+ norm=norm))
75
+ in_chs = out_chs
76
+
77
+ out_chs = min(in_chs * 2, max_filters)
78
+ self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
79
+ padding=(kernel_sizes[0] - 1) // 2, norm=norm))
80
+ self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
81
+ padding=(kernel_sizes[1] - 1) // 2, norm=norm)
82
+
83
+ def forward(self, x: torch.Tensor):
84
+ fmap = []
85
+ for layer in self.convs:
86
+ x = layer(x)
87
+ x = self.activation(x)
88
+ fmap.append(x)
89
+ x = self.conv_post(x)
90
+ fmap.append(x)
91
+ # x = torch.flatten(x, 1, -1)
92
+ return x, fmap
93
+
94
+
95
+ class MultiScaleDiscriminator(MultiDiscriminator):
96
+ """Multi-Scale (MSD) Discriminator,
97
+
98
+ Args:
99
+ in_channels (int): Number of input channels.
100
+ out_channels (int): Number of output channels.
101
+ downsample_factor (int): Downsampling factor between the different scales.
102
+ scale_norms (Sequence[str]): Normalization for each sub-discriminator.
103
+ **kwargs: Additional args for ScaleDiscriminator.
104
+ """
105
+ def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
106
+ scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
107
+ super().__init__()
108
+ self.discriminators = nn.ModuleList([
109
+ ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
110
+ ])
111
+ self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
112
+
113
+ @property
114
+ def num_discriminators(self):
115
+ return len(self.discriminators)
116
+
117
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
118
+ logits = []
119
+ fmaps = []
120
+ for i, disc in enumerate(self.discriminators):
121
+ if i != 0:
122
+ self.downsample(x)
123
+ logit, fmap = disc(x)
124
+ logits.append(logit)
125
+ fmaps.append(fmap)
126
+ return logits, fmaps
audiocraft/adversarial/discriminators/msstftd.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import torchaudio
10
+ import torch
11
+ from torch import nn
12
+ from einops import rearrange
13
+
14
+ from ...modules import NormConv2d
15
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
16
+
17
+
18
+ def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
19
+ return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
20
+
21
+
22
+ class DiscriminatorSTFT(nn.Module):
23
+ """STFT sub-discriminator.
24
+
25
+ Args:
26
+ filters (int): Number of filters in convolutions.
27
+ in_channels (int): Number of input channels.
28
+ out_channels (int): Number of output channels.
29
+ n_fft (int): Size of FFT for each scale.
30
+ hop_length (int): Length of hop between STFT windows for each scale.
31
+ kernel_size (tuple of int): Inner Conv2d kernel sizes.
32
+ stride (tuple of int): Inner Conv2d strides.
33
+ dilations (list of int): Inner Conv2d dilation on the time dimension.
34
+ win_length (int): Window size for each scale.
35
+ normalized (bool): Whether to normalize by magnitude after stft.
36
+ norm (str): Normalization method.
37
+ activation (str): Activation function.
38
+ activation_params (dict): Parameters to provide to the activation function.
39
+ growth (int): Growth factor for the filters.
40
+ """
41
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
42
+ n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
43
+ filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
44
+ stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
45
+ activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
46
+ super().__init__()
47
+ assert len(kernel_size) == 2
48
+ assert len(stride) == 2
49
+ self.filters = filters
50
+ self.in_channels = in_channels
51
+ self.out_channels = out_channels
52
+ self.n_fft = n_fft
53
+ self.hop_length = hop_length
54
+ self.win_length = win_length
55
+ self.normalized = normalized
56
+ self.activation = getattr(torch.nn, activation)(**activation_params)
57
+ self.spec_transform = torchaudio.transforms.Spectrogram(
58
+ n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
59
+ normalized=self.normalized, center=False, pad_mode=None, power=None)
60
+ spec_channels = 2 * self.in_channels
61
+ self.convs = nn.ModuleList()
62
+ self.convs.append(
63
+ NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
64
+ )
65
+ in_chs = min(filters_scale * self.filters, max_filters)
66
+ for i, dilation in enumerate(dilations):
67
+ out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
68
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
69
+ dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
70
+ norm=norm))
71
+ in_chs = out_chs
72
+ out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
73
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
74
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
75
+ norm=norm))
76
+ self.conv_post = NormConv2d(out_chs, self.out_channels,
77
+ kernel_size=(kernel_size[0], kernel_size[0]),
78
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
79
+ norm=norm)
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ fmap = []
83
+ z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
84
+ z = torch.cat([z.real, z.imag], dim=1)
85
+ z = rearrange(z, 'b c w t -> b c t w')
86
+ for i, layer in enumerate(self.convs):
87
+ z = layer(z)
88
+ z = self.activation(z)
89
+ fmap.append(z)
90
+ z = self.conv_post(z)
91
+ return z, fmap
92
+
93
+
94
+ class MultiScaleSTFTDiscriminator(MultiDiscriminator):
95
+ """Multi-Scale STFT (MS-STFT) discriminator.
96
+
97
+ Args:
98
+ filters (int): Number of filters in convolutions.
99
+ in_channels (int): Number of input channels.
100
+ out_channels (int): Number of output channels.
101
+ sep_channels (bool): Separate channels to distinct samples for stereo support.
102
+ n_ffts (Sequence[int]): Size of FFT for each scale.
103
+ hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
104
+ win_lengths (Sequence[int]): Window size for each scale.
105
+ **kwargs: Additional args for STFTDiscriminator.
106
+ """
107
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
108
+ n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
109
+ win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
110
+ super().__init__()
111
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
112
+ self.sep_channels = sep_channels
113
+ self.discriminators = nn.ModuleList([
114
+ DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
115
+ n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
116
+ for i in range(len(n_ffts))
117
+ ])
118
+
119
+ @property
120
+ def num_discriminators(self):
121
+ return len(self.discriminators)
122
+
123
+ def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
124
+ B, C, T = x.shape
125
+ return x.view(-1, 1, T)
126
+
127
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
128
+ logits = []
129
+ fmaps = []
130
+ for disc in self.discriminators:
131
+ logit, fmap = disc(x)
132
+ logits.append(logit)
133
+ fmaps.append(fmap)
134
+ return logits, fmaps
audiocraft/adversarial/losses.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility module to handle adversarial losses without requiring to mess up the main training loop.
9
+ """
10
+
11
+ import typing as tp
12
+
13
+ import flashy
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
20
+
21
+
22
+ AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
23
+ FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
24
+
25
+
26
+ class AdversarialLoss(nn.Module):
27
+ """Adversary training wrapper.
28
+
29
+ Args:
30
+ adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
31
+ We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
32
+ where the first item is a list of logits and the second item is a list of feature maps.
33
+ optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
34
+ loss (AdvLossType): Loss function for generator training.
35
+ loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
36
+ loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
37
+ loss_feat (FeatLossType): Feature matching loss function for generator training.
38
+ normalize (bool): Whether to normalize by number of sub-discriminators.
39
+
40
+ Example of usage:
41
+ adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
42
+ for real in loader:
43
+ noise = torch.randn(...)
44
+ fake = model(noise)
45
+ adv_loss.train_adv(fake, real)
46
+ loss, _ = adv_loss(fake, real)
47
+ loss.backward()
48
+ """
49
+ def __init__(self,
50
+ adversary: nn.Module,
51
+ optimizer: torch.optim.Optimizer,
52
+ loss: AdvLossType,
53
+ loss_real: AdvLossType,
54
+ loss_fake: AdvLossType,
55
+ loss_feat: tp.Optional[FeatLossType] = None,
56
+ normalize: bool = True):
57
+ super().__init__()
58
+ self.adversary: nn.Module = adversary
59
+ flashy.distrib.broadcast_model(self.adversary)
60
+ self.optimizer = optimizer
61
+ self.loss = loss
62
+ self.loss_real = loss_real
63
+ self.loss_fake = loss_fake
64
+ self.loss_feat = loss_feat
65
+ self.normalize = normalize
66
+
67
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
68
+ # Add the optimizer state dict inside our own.
69
+ super()._save_to_state_dict(destination, prefix, keep_vars)
70
+ destination[prefix + 'optimizer'] = self.optimizer.state_dict()
71
+ return destination
72
+
73
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
74
+ # Load optimizer state.
75
+ self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
76
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
77
+
78
+ def get_adversary_pred(self, x):
79
+ """Run adversary model, validating expected output format."""
80
+ logits, fmaps = self.adversary(x)
81
+ assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
82
+ f'Expecting a list of tensors as logits but {type(logits)} found.'
83
+ assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
84
+ for fmap in fmaps:
85
+ assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
86
+ f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
87
+ return logits, fmaps
88
+
89
+ def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
90
+ """Train the adversary with the given fake and real example.
91
+
92
+ We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
93
+ The first item being the logits and second item being a list of feature maps for each sub-discriminator.
94
+
95
+ This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
96
+ and call the optimizer.
97
+ """
98
+ loss = torch.tensor(0., device=fake.device)
99
+ all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
100
+ all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
101
+ n_sub_adversaries = len(all_logits_fake_is_fake)
102
+ for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
103
+ loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
104
+
105
+ if self.normalize:
106
+ loss /= n_sub_adversaries
107
+
108
+ self.optimizer.zero_grad()
109
+ with flashy.distrib.eager_sync_model(self.adversary):
110
+ loss.backward()
111
+ self.optimizer.step()
112
+
113
+ return loss
114
+
115
+ def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
116
+ """Return the loss for the generator, i.e. trying to fool the adversary,
117
+ and feature matching loss if provided.
118
+ """
119
+ adv = torch.tensor(0., device=fake.device)
120
+ feat = torch.tensor(0., device=fake.device)
121
+ with flashy.utils.readonly(self.adversary):
122
+ all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
123
+ all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
124
+ n_sub_adversaries = len(all_logits_fake_is_fake)
125
+ for logit_fake_is_fake in all_logits_fake_is_fake:
126
+ adv += self.loss(logit_fake_is_fake)
127
+ if self.loss_feat:
128
+ for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
129
+ feat += self.loss_feat(fmap_fake, fmap_real)
130
+
131
+ if self.normalize:
132
+ adv /= n_sub_adversaries
133
+ feat /= n_sub_adversaries
134
+
135
+ return adv, feat
136
+
137
+
138
+ def get_adv_criterion(loss_type: str) -> tp.Callable:
139
+ assert loss_type in ADVERSARIAL_LOSSES
140
+ if loss_type == 'mse':
141
+ return mse_loss
142
+ elif loss_type == 'hinge':
143
+ return hinge_loss
144
+ elif loss_type == 'hinge2':
145
+ return hinge2_loss
146
+ raise ValueError('Unsupported loss')
147
+
148
+
149
+ def get_fake_criterion(loss_type: str) -> tp.Callable:
150
+ assert loss_type in ADVERSARIAL_LOSSES
151
+ if loss_type == 'mse':
152
+ return mse_fake_loss
153
+ elif loss_type in ['hinge', 'hinge2']:
154
+ return hinge_fake_loss
155
+ raise ValueError('Unsupported loss')
156
+
157
+
158
+ def get_real_criterion(loss_type: str) -> tp.Callable:
159
+ assert loss_type in ADVERSARIAL_LOSSES
160
+ if loss_type == 'mse':
161
+ return mse_real_loss
162
+ elif loss_type in ['hinge', 'hinge2']:
163
+ return hinge_real_loss
164
+ raise ValueError('Unsupported loss')
165
+
166
+
167
+ def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
168
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
169
+
170
+
171
+ def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
172
+ return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
173
+
174
+
175
+ def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
176
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
177
+
178
+
179
+ def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
180
+ return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
181
+
182
+
183
+ def mse_loss(x: torch.Tensor) -> torch.Tensor:
184
+ if x.numel() == 0:
185
+ return torch.tensor([0.0], device=x.device)
186
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
187
+
188
+
189
+ def hinge_loss(x: torch.Tensor) -> torch.Tensor:
190
+ if x.numel() == 0:
191
+ return torch.tensor([0.0], device=x.device)
192
+ return -x.mean()
193
+
194
+
195
+ def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
196
+ if x.numel() == 0:
197
+ return torch.tensor([0.0])
198
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
199
+
200
+
201
+ class FeatureMatchingLoss(nn.Module):
202
+ """Feature matching loss for adversarial training.
203
+
204
+ Args:
205
+ loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
206
+ normalize (bool): Whether to normalize the loss.
207
+ by number of feature maps.
208
+ """
209
+ def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
210
+ super().__init__()
211
+ self.loss = loss
212
+ self.normalize = normalize
213
+
214
+ def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
215
+ assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
216
+ feat_loss = torch.tensor(0., device=fmap_fake[0].device)
217
+ feat_scale = torch.tensor(0., device=fmap_fake[0].device)
218
+ n_fmaps = 0
219
+ for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
220
+ assert feat_fake.shape == feat_real.shape
221
+ n_fmaps += 1
222
+ feat_loss += self.loss(feat_fake, feat_real)
223
+ feat_scale += torch.mean(torch.abs(feat_real))
224
+
225
+ if self.normalize:
226
+ feat_loss /= n_fmaps
227
+
228
+ return feat_loss
audiocraft/data/__init__.py CHANGED
@@ -3,6 +3,8 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
 
 
6
 
7
  # flake8: noqa
8
- from . import audio, audio_dataset
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """Audio loading and writing support. Datasets for raw audio
7
+ or also including some metadata."""
8
 
9
  # flake8: noqa
10
+ from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
audiocraft/data/audio.py CHANGED
@@ -18,11 +18,11 @@ import numpy as np
18
  import soundfile
19
  import torch
20
  from torch.nn import functional as F
21
- import torchaudio as ta
22
 
23
  import av
 
24
 
25
- from .audio_utils import f32_pcm, i16_pcm, normalize_audio
26
 
27
 
28
  _av_initialized = False
@@ -78,7 +78,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
78
  seek_time (float): Time at which to start reading in the file.
79
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
80
  Returns:
81
- Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
82
  """
83
  _init_av()
84
  with av.open(str(filepath)) as af:
@@ -123,7 +123,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
123
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
124
  pad (bool): Pad output audio if not reaching expected duration.
125
  Returns:
126
- Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
127
  """
128
  fp = Path(filepath)
129
  if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
@@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
136
  wav = torch.from_numpy(wav).t().contiguous()
137
  if len(wav.shape) == 1:
138
  wav = torch.unsqueeze(wav, 0)
139
- elif (
140
- fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
141
- and duration <= 0 and seek_time == 0
142
- ):
143
- # Torchaudio is faster if we load an entire file at once.
144
- wav, sr = ta.load(fp)
145
  else:
146
  wav, sr = _av_read(filepath, seek_time, duration)
147
  if pad and duration > 0:
@@ -150,10 +144,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
150
  return wav, sr
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def audio_write(stem_name: tp.Union[str, Path],
154
  wav: torch.Tensor, sample_rate: int,
155
- format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
156
- strategy: str = 'peak', peak_clip_headroom_db: float = 1,
157
  rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
158
  loudness_compressor: bool = False,
159
  log_clipping: bool = True, make_parent_dir: bool = True,
@@ -162,8 +168,11 @@ def audio_write(stem_name: tp.Union[str, Path],
162
 
163
  Args:
164
  stem_name (str or Path): Filename without extension which will be added automatically.
165
- format (str): Either "wav" or "mp3".
 
 
166
  mp3_rate (int): kbps when using mp3s.
 
167
  normalize (bool): if `True` (default), normalizes according to the prescribed
168
  strategy (see after). If `False`, the strategy is only used in case clipping
169
  would happen.
@@ -175,7 +184,7 @@ def audio_write(stem_name: tp.Union[str, Path],
175
  than the `peak_clip` one to avoid further clipping.
176
  loudness_headroom_db (float): Target loudness for loudness normalization.
177
  loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
178
- when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
179
  occurs despite strategy (only for 'rms').
180
  make_parent_dir (bool): Make parent directory if it doesn't exist.
181
  Returns:
@@ -188,16 +197,23 @@ def audio_write(stem_name: tp.Union[str, Path],
188
  raise ValueError("Input wav should be at most 2 dimension.")
189
  assert wav.isfinite().all()
190
  wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
191
- rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
192
- sample_rate=sample_rate, stem_name=str(stem_name))
193
- kwargs: dict = {}
194
  if format == 'mp3':
195
  suffix = '.mp3'
196
- kwargs.update({"compression": mp3_rate})
197
  elif format == 'wav':
198
- wav = i16_pcm(wav)
199
  suffix = '.wav'
200
- kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
 
 
 
 
 
 
 
 
201
  else:
202
  raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
203
  if not add_suffix:
@@ -206,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path],
206
  if make_parent_dir:
207
  path.parent.mkdir(exist_ok=True, parents=True)
208
  try:
209
- ta.save(path, wav, sample_rate, **kwargs)
210
  except Exception:
211
  if path.exists():
212
  # we do not want to leave half written files around.
 
18
  import soundfile
19
  import torch
20
  from torch.nn import functional as F
 
21
 
22
  import av
23
+ import subprocess as sp
24
 
25
+ from .audio_utils import f32_pcm, normalize_audio
26
 
27
 
28
  _av_initialized = False
 
78
  seek_time (float): Time at which to start reading in the file.
79
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
80
  Returns:
81
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate
82
  """
83
  _init_av()
84
  with av.open(str(filepath)) as af:
 
123
  duration (float): Duration to read from the file. If set to -1, the whole file is read.
124
  pad (bool): Pad output audio if not reaching expected duration.
125
  Returns:
126
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
127
  """
128
  fp = Path(filepath)
129
  if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
 
136
  wav = torch.from_numpy(wav).t().contiguous()
137
  if len(wav.shape) == 1:
138
  wav = torch.unsqueeze(wav, 0)
 
 
 
 
 
 
139
  else:
140
  wav, sr = _av_read(filepath, seek_time, duration)
141
  if pad and duration > 0:
 
144
  return wav, sr
145
 
146
 
147
+ def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
148
+ # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
149
+ assert wav.dim() == 2, wav.shape
150
+ command = [
151
+ 'ffmpeg',
152
+ '-loglevel', 'error',
153
+ '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
154
+ '-i', '-'] + flags + [str(out_path)]
155
+ input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
156
+ sp.run(command, input=input_, check=True)
157
+
158
+
159
  def audio_write(stem_name: tp.Union[str, Path],
160
  wav: torch.Tensor, sample_rate: int,
161
+ format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
162
+ normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
163
  rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
164
  loudness_compressor: bool = False,
165
  log_clipping: bool = True, make_parent_dir: bool = True,
 
168
 
169
  Args:
170
  stem_name (str or Path): Filename without extension which will be added automatically.
171
+ wav (torch.Tensor): Audio data to save.
172
+ sample_rate (int): Sample rate of audio data.
173
+ format (str): Either "wav", "mp3", "ogg", or "flac".
174
  mp3_rate (int): kbps when using mp3s.
175
+ ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
176
  normalize (bool): if `True` (default), normalizes according to the prescribed
177
  strategy (see after). If `False`, the strategy is only used in case clipping
178
  would happen.
 
184
  than the `peak_clip` one to avoid further clipping.
185
  loudness_headroom_db (float): Target loudness for loudness normalization.
186
  loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
187
+ when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
188
  occurs despite strategy (only for 'rms').
189
  make_parent_dir (bool): Make parent directory if it doesn't exist.
190
  Returns:
 
197
  raise ValueError("Input wav should be at most 2 dimension.")
198
  assert wav.isfinite().all()
199
  wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
200
+ rms_headroom_db, loudness_headroom_db, loudness_compressor,
201
+ log_clipping=log_clipping, sample_rate=sample_rate,
202
+ stem_name=str(stem_name))
203
  if format == 'mp3':
204
  suffix = '.mp3'
205
+ flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
206
  elif format == 'wav':
 
207
  suffix = '.wav'
208
+ flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
209
+ elif format == 'ogg':
210
+ suffix = '.ogg'
211
+ flags = ['-f', 'ogg', '-c:a', 'libvorbis']
212
+ if ogg_rate is not None:
213
+ flags += ['-b:a', f'{ogg_rate}k']
214
+ elif format == 'flac':
215
+ suffix = '.flac'
216
+ flags = ['-f', 'flac']
217
  else:
218
  raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
219
  if not add_suffix:
 
222
  if make_parent_dir:
223
  path.parent.mkdir(exist_ok=True, parents=True)
224
  try:
225
+ _piping_to_ffmpeg(path, wav, sample_rate, flags)
226
  except Exception:
227
  if path.exists():
228
  # we do not want to leave half written files around.
audiocraft/data/audio_dataset.py CHANGED
@@ -3,12 +3,16 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
-
 
 
 
7
  import argparse
8
  import copy
9
  from concurrent.futures import ThreadPoolExecutor, Future
10
  from dataclasses import dataclass, fields
11
  from contextlib import ExitStack
 
12
  import gzip
13
  import json
14
  import logging
@@ -81,9 +85,12 @@ class AudioMeta(BaseInfo):
81
  class SegmentInfo(BaseInfo):
82
  meta: AudioMeta
83
  seek_time: float
84
- n_frames: int # actual number of frames without padding
 
 
85
  total_frames: int # total number of frames, padding included
86
- sample_rate: int # actual sample rate
 
87
 
88
 
89
  DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
@@ -114,8 +121,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
114
 
115
  Args:
116
  m (AudioMeta): Audio meta to resolve.
117
- fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
118
- Only valid on Linux/Mac.
119
  Returns:
120
  AudioMeta: Audio meta with resolved path.
121
  """
@@ -151,7 +158,7 @@ def find_audio_files(path: tp.Union[Path, str],
151
  progress (bool): Whether to log progress on audio files collection.
152
  workers (int): number of parallel workers, if 0, use only the current thread.
153
  Returns:
154
- List[AudioMeta]: List of audio file path and its metadata.
155
  """
156
  audio_files = []
157
  futures: tp.List[Future] = []
@@ -203,7 +210,7 @@ def load_audio_meta(path: tp.Union[str, Path],
203
  resolve (bool): Whether to resolve the path from AudioMeta (default=True).
204
  fast (bool): activates some tricks to make things faster.
205
  Returns:
206
- List[AudioMeta]: List of audio file path and its total duration.
207
  """
208
  open_fn = gzip.open if str(path).lower().endswith('.gz') else open
209
  with open_fn(path, 'rb') as fp: # type: ignore
@@ -250,9 +257,14 @@ class AudioDataset:
250
  allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
251
  original audio meta.
252
 
 
 
 
 
 
253
  Args:
254
- meta (tp.List[AudioMeta]): List of audio files metadata.
255
- segment_duration (float): Optional segment duration of audio to load.
256
  If not specified, the dataset will load the full audio segment from the file.
257
  shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
258
  sample_rate (int): Target sample rate of the loaded audio samples.
@@ -266,10 +278,19 @@ class AudioDataset:
266
  is shorter than the desired segment.
267
  max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
268
  return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
269
- min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
270
  audio shorter than this will be filtered out.
271
- max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
272
  audio longer than this will be filtered out.
 
 
 
 
 
 
 
 
 
273
  """
274
  def __init__(self,
275
  meta: tp.List[AudioMeta],
@@ -285,16 +306,14 @@ class AudioDataset:
285
  max_read_retry: int = 10,
286
  return_info: bool = False,
287
  min_audio_duration: tp.Optional[float] = None,
288
- max_audio_duration: tp.Optional[float] = None
 
 
 
289
  ):
290
- assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
291
  assert segment_duration is None or segment_duration > 0
292
  assert segment_duration is None or min_segment_ratio >= 0
293
- logging.debug(f'sample_on_duration: {sample_on_duration}')
294
- logging.debug(f'sample_on_weight: {sample_on_weight}')
295
- logging.debug(f'pad: {pad}')
296
- logging.debug(f'min_segment_ratio: {min_segment_ratio}')
297
-
298
  self.segment_duration = segment_duration
299
  self.min_segment_ratio = min_segment_ratio
300
  self.max_audio_duration = max_audio_duration
@@ -317,13 +336,25 @@ class AudioDataset:
317
  self.sampling_probabilities = self._get_sampling_probabilities()
318
  self.max_read_retry = max_read_retry
319
  self.return_info = return_info
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  def __len__(self):
322
  return self.num_samples
323
 
324
  def _get_sampling_probabilities(self, normalized: bool = True):
325
- """Return the sampling probabilities for each file inside `self.meta`.
326
- """
327
  scores: tp.List[float] = []
328
  for file_meta in self.meta:
329
  score = 1.
@@ -337,12 +368,32 @@ class AudioDataset:
337
  probabilities /= probabilities.sum()
338
  return probabilities
339
 
340
- def sample_file(self, rng: torch.Generator) -> AudioMeta:
341
- """Sample a given file from `self.meta`. Can be overriden in subclasses.
 
 
 
 
 
 
 
 
 
342
  This is only called if `segment_duration` is not None.
343
 
344
  You must use the provided random number generator `rng` for reproducibility.
 
345
  """
 
 
 
 
 
 
 
 
 
 
346
  if not self.sample_on_weight and not self.sample_on_duration:
347
  file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
348
  else:
@@ -350,6 +401,15 @@ class AudioDataset:
350
 
351
  return self.meta[file_index]
352
 
 
 
 
 
 
 
 
 
 
353
  def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
354
  if self.segment_duration is None:
355
  file_meta = self.meta[index]
@@ -357,18 +417,22 @@ class AudioDataset:
357
  out = convert_audio(out, sr, self.sample_rate, self.channels)
358
  n_frames = out.shape[-1]
359
  segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
360
- sample_rate=self.sample_rate)
361
  else:
362
  rng = torch.Generator()
363
  if self.shuffle:
364
- # We use index, plus extra randomness
365
- rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
 
 
 
 
366
  else:
367
  # We only use index
368
  rng.manual_seed(index)
369
 
370
  for retry in range(self.max_read_retry):
371
- file_meta = self.sample_file(rng)
372
  # We add some variance in the file position even if audio file is smaller than segment
373
  # without ending up with empty segments
374
  max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
@@ -381,7 +445,7 @@ class AudioDataset:
381
  if self.pad:
382
  out = F.pad(out, (0, target_frames - n_frames))
383
  segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
384
- sample_rate=self.sample_rate)
385
  except Exception as exc:
386
  logger.warning("Error opening file %s: %r", file_meta.path, exc)
387
  if retry == self.max_read_retry - 1:
@@ -423,7 +487,7 @@ class AudioDataset:
423
  if to_pad:
424
  # Each wav could be of a different duration as they are not segmented.
425
  for i in range(len(samples)):
426
- # Determines the total legth of the signal with padding, so we update here as we pad.
427
  segment_infos[i].total_frames = max_len
428
  wavs[i] = _pad_wav(wavs[i])
429
 
@@ -436,9 +500,7 @@ class AudioDataset:
436
  return torch.stack(samples)
437
 
438
  def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
439
- """Filters out audio files with short durations.
440
- Removes from meta files that have durations that will not allow to samples examples from them.
441
- """
442
  orig_len = len(meta)
443
 
444
  # Filter data that is too short.
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """AudioDataset support. In order to handle a larger number of files
7
+ without having to scan again the folders, we precompute some metadata
8
+ (filename, sample rate, duration), and use that to efficiently sample audio segments.
9
+ """
10
  import argparse
11
  import copy
12
  from concurrent.futures import ThreadPoolExecutor, Future
13
  from dataclasses import dataclass, fields
14
  from contextlib import ExitStack
15
+ from functools import lru_cache
16
  import gzip
17
  import json
18
  import logging
 
85
  class SegmentInfo(BaseInfo):
86
  meta: AudioMeta
87
  seek_time: float
88
+ # The following values are given once the audio is processed, e.g.
89
+ # at the target sample rate and target number of channels.
90
+ n_frames: int # actual number of frames without padding
91
  total_frames: int # total number of frames, padding included
92
+ sample_rate: int # actual sample rate
93
+ channels: int # number of audio channels.
94
 
95
 
96
  DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
 
121
 
122
  Args:
123
  m (AudioMeta): Audio meta to resolve.
124
+ fast (bool): If True, uses a really fast check for determining if a file
125
+ is already absolute or not. Only valid on Linux/Mac.
126
  Returns:
127
  AudioMeta: Audio meta with resolved path.
128
  """
 
158
  progress (bool): Whether to log progress on audio files collection.
159
  workers (int): number of parallel workers, if 0, use only the current thread.
160
  Returns:
161
+ list of AudioMeta: List of audio file path and its metadata.
162
  """
163
  audio_files = []
164
  futures: tp.List[Future] = []
 
210
  resolve (bool): Whether to resolve the path from AudioMeta (default=True).
211
  fast (bool): activates some tricks to make things faster.
212
  Returns:
213
+ list of AudioMeta: List of audio file path and its total duration.
214
  """
215
  open_fn = gzip.open if str(path).lower().endswith('.gz') else open
216
  with open_fn(path, 'rb') as fp: # type: ignore
 
257
  allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
258
  original audio meta.
259
 
260
+ Note that you can call `start_epoch(epoch)` in order to get
261
+ a deterministic "randomization" for `shuffle=True`.
262
+ For a given epoch and dataset index, this will always return the same extract.
263
+ You can get back some diversity by setting the `shuffle_seed` param.
264
+
265
  Args:
266
+ meta (list of AudioMeta): List of audio files metadata.
267
+ segment_duration (float, optional): Optional segment duration of audio to load.
268
  If not specified, the dataset will load the full audio segment from the file.
269
  shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
270
  sample_rate (int): Target sample rate of the loaded audio samples.
 
278
  is shorter than the desired segment.
279
  max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
280
  return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
281
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
282
  audio shorter than this will be filtered out.
283
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
284
  audio longer than this will be filtered out.
285
+ shuffle_seed (int): can be used to further randomize
286
+ load_wav (bool): if False, skip loading the wav but returns a tensor of 0
287
+ with the expected segment_duration (which must be provided if load_wav is False).
288
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
289
+ are False. Will ensure a permutation on files when going through the dataset.
290
+ In that case the epoch number must be provided in order for the model
291
+ to continue the permutation across epochs. In that case, it is assumed
292
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
293
+ `total_batch_size` the overall batch size accounting for all gpus.
294
  """
295
  def __init__(self,
296
  meta: tp.List[AudioMeta],
 
306
  max_read_retry: int = 10,
307
  return_info: bool = False,
308
  min_audio_duration: tp.Optional[float] = None,
309
+ max_audio_duration: tp.Optional[float] = None,
310
+ shuffle_seed: int = 0,
311
+ load_wav: bool = True,
312
+ permutation_on_files: bool = False,
313
  ):
314
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
315
  assert segment_duration is None or segment_duration > 0
316
  assert segment_duration is None or min_segment_ratio >= 0
 
 
 
 
 
317
  self.segment_duration = segment_duration
318
  self.min_segment_ratio = min_segment_ratio
319
  self.max_audio_duration = max_audio_duration
 
336
  self.sampling_probabilities = self._get_sampling_probabilities()
337
  self.max_read_retry = max_read_retry
338
  self.return_info = return_info
339
+ self.shuffle_seed = shuffle_seed
340
+ self.current_epoch: tp.Optional[int] = None
341
+ self.load_wav = load_wav
342
+ if not load_wav:
343
+ assert segment_duration is not None
344
+ self.permutation_on_files = permutation_on_files
345
+ if permutation_on_files:
346
+ assert not self.sample_on_duration
347
+ assert not self.sample_on_weight
348
+ assert self.shuffle
349
+
350
+ def start_epoch(self, epoch: int):
351
+ self.current_epoch = epoch
352
 
353
  def __len__(self):
354
  return self.num_samples
355
 
356
  def _get_sampling_probabilities(self, normalized: bool = True):
357
+ """Return the sampling probabilities for each file inside `self.meta`."""
 
358
  scores: tp.List[float] = []
359
  for file_meta in self.meta:
360
  score = 1.
 
368
  probabilities /= probabilities.sum()
369
  return probabilities
370
 
371
+ @staticmethod
372
+ @lru_cache(16)
373
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
374
+ # Used to keep the most recent files permutation in memory implicitely.
375
+ # will work unless someone is using a lot of Datasets in parallel.
376
+ rng = torch.Generator()
377
+ rng.manual_seed(base_seed + permutation_index)
378
+ return torch.randperm(num_files, generator=rng)
379
+
380
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
381
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
382
  This is only called if `segment_duration` is not None.
383
 
384
  You must use the provided random number generator `rng` for reproducibility.
385
+ You can further make use of the index accessed.
386
  """
387
+ if self.permutation_on_files:
388
+ assert self.current_epoch is not None
389
+ total_index = self.current_epoch * len(self) + index
390
+ permutation_index = total_index // len(self.meta)
391
+ relative_index = total_index % len(self.meta)
392
+ permutation = AudioDataset._get_file_permutation(
393
+ len(self.meta), permutation_index, self.shuffle_seed)
394
+ file_index = permutation[relative_index]
395
+ return self.meta[file_index]
396
+
397
  if not self.sample_on_weight and not self.sample_on_duration:
398
  file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
399
  else:
 
401
 
402
  return self.meta[file_index]
403
 
404
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
405
+ # Override this method in subclass if needed.
406
+ if self.load_wav:
407
+ return audio_read(path, seek_time, duration, pad=False)
408
+ else:
409
+ assert self.segment_duration is not None
410
+ n_frames = int(self.sample_rate * self.segment_duration)
411
+ return torch.zeros(self.channels, n_frames), self.sample_rate
412
+
413
  def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
414
  if self.segment_duration is None:
415
  file_meta = self.meta[index]
 
417
  out = convert_audio(out, sr, self.sample_rate, self.channels)
418
  n_frames = out.shape[-1]
419
  segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
420
+ sample_rate=self.sample_rate, channels=out.shape[0])
421
  else:
422
  rng = torch.Generator()
423
  if self.shuffle:
424
+ # We use index, plus extra randomness, either totally random if we don't know the epoch.
425
+ # otherwise we make use of the epoch number and optional shuffle_seed.
426
+ if self.current_epoch is None:
427
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
428
+ else:
429
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
430
  else:
431
  # We only use index
432
  rng.manual_seed(index)
433
 
434
  for retry in range(self.max_read_retry):
435
+ file_meta = self.sample_file(index, rng)
436
  # We add some variance in the file position even if audio file is smaller than segment
437
  # without ending up with empty segments
438
  max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
 
445
  if self.pad:
446
  out = F.pad(out, (0, target_frames - n_frames))
447
  segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
448
+ sample_rate=self.sample_rate, channels=out.shape[0])
449
  except Exception as exc:
450
  logger.warning("Error opening file %s: %r", file_meta.path, exc)
451
  if retry == self.max_read_retry - 1:
 
487
  if to_pad:
488
  # Each wav could be of a different duration as they are not segmented.
489
  for i in range(len(samples)):
490
+ # Determines the total length of the signal with padding, so we update here as we pad.
491
  segment_infos[i].total_frames = max_len
492
  wavs[i] = _pad_wav(wavs[i])
493
 
 
500
  return torch.stack(samples)
501
 
502
  def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
503
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
 
 
504
  orig_len = len(meta)
505
 
506
  # Filter data that is too short.
audiocraft/data/audio_utils.py CHANGED
@@ -3,7 +3,8 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
-
 
7
  import sys
8
  import typing as tp
9
 
@@ -47,8 +48,7 @@ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor
47
 
48
  def convert_audio(wav: torch.Tensor, from_rate: float,
49
  to_rate: float, to_channels: int) -> torch.Tensor:
50
- """Convert audio to new sample rate and number of audio channels.
51
- """
52
  wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
  wav = convert_audio_channels(wav, to_channels)
54
  return wav
@@ -66,7 +66,7 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db
66
  loudness_compressor (bool): Uses tanh for soft clipping.
67
  energy_floor (float): anything below that RMS level will not be rescaled.
68
  Returns:
69
- output (torch.Tensor): Loudness normalized output data.
70
  """
71
  energy = wav.pow(2).mean().sqrt().item()
72
  if energy < energy_floor:
@@ -117,7 +117,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
117
  log_clipping (bool): If True, basic logging on stderr when clipping still
118
  occurs despite strategy (only for 'rms').
119
  sample_rate (int): Sample rate for the audio data (required for loudness).
120
- stem_name (Optional[str]): Stem name for clipping logging.
121
  Returns:
122
  torch.Tensor: Normalized audio.
123
  """
@@ -150,17 +150,19 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
150
  """
151
  if wav.dtype.is_floating_point:
152
  return wav
153
- else:
154
- assert wav.dtype == torch.int16
155
  return wav.float() / 2**15
 
 
 
156
 
157
 
158
  def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
159
  """Convert audio to int 16 bits PCM format.
160
 
161
- ..Warning:: There exist many formula for doing this convertion. None are perfect
162
- due to the asymetry of the int16 range. One either have possible clipping, DC offset,
163
- or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom,
164
  it is possible that `i16_pcm(f32_pcm)) != Identity`.
165
  """
166
  if wav.dtype.is_floating_point:
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """Various utilities for audio convertion (pcm format, sample rate and channels),
7
+ and volume normalization."""
8
  import sys
9
  import typing as tp
10
 
 
48
 
49
  def convert_audio(wav: torch.Tensor, from_rate: float,
50
  to_rate: float, to_channels: int) -> torch.Tensor:
51
+ """Convert audio to new sample rate and number of audio channels."""
 
52
  wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
  wav = convert_audio_channels(wav, to_channels)
54
  return wav
 
66
  loudness_compressor (bool): Uses tanh for soft clipping.
67
  energy_floor (float): anything below that RMS level will not be rescaled.
68
  Returns:
69
+ torch.Tensor: Loudness normalized output data.
70
  """
71
  energy = wav.pow(2).mean().sqrt().item()
72
  if energy < energy_floor:
 
117
  log_clipping (bool): If True, basic logging on stderr when clipping still
118
  occurs despite strategy (only for 'rms').
119
  sample_rate (int): Sample rate for the audio data (required for loudness).
120
+ stem_name (str, optional): Stem name for clipping logging.
121
  Returns:
122
  torch.Tensor: Normalized audio.
123
  """
 
150
  """
151
  if wav.dtype.is_floating_point:
152
  return wav
153
+ elif wav.dtype == torch.int16:
 
154
  return wav.float() / 2**15
155
+ elif wav.dtype == torch.int32:
156
+ return wav.float() / 2**31
157
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
158
 
159
 
160
  def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
161
  """Convert audio to int 16 bits PCM format.
162
 
163
+ ..Warning:: There exist many formula for doing this conversion. None are perfect
164
+ due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
165
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
166
  it is possible that `i16_pcm(f32_pcm)) != Identity`.
167
  """
168
  if wav.dtype.is_floating_point:
audiocraft/data/info_audio_dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Base classes for the datasets that also provide non-audio metadata,
7
+ e.g. description, text transcription etc.
8
+ """
9
+ from dataclasses import dataclass
10
+ import logging
11
+ import math
12
+ import re
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .audio_dataset import AudioDataset, AudioMeta
18
+ from ..environment import AudioCraftEnvironment
19
+ from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
+ """Monkey-patch meta to match cluster specificities."""
27
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
+ if meta.info_path is not None:
29
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
+ return meta
31
+
32
+
33
+ def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
+ """Monkey-patch all meta to match cluster specificities."""
35
+ return [_clusterify_meta(m) for m in meta]
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo(SegmentWithAttributes):
40
+ """Dummy SegmentInfo with empty attributes.
41
+
42
+ The InfoAudioDataset is expected to return metadata that inherits
43
+ from SegmentWithAttributes class and can return conditioning attributes.
44
+
45
+ This basically guarantees all datasets will be compatible with current
46
+ solver that contain conditioners requiring this.
47
+ """
48
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
+
50
+ def to_condition_attributes(self) -> ConditioningAttributes:
51
+ return ConditioningAttributes()
52
+
53
+
54
+ class InfoAudioDataset(AudioDataset):
55
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
+
57
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
+ """
59
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
+ super().__init__(clusterify_all_meta(meta), **kwargs)
61
+
62
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
+ if not self.return_info:
64
+ wav = super().__getitem__(index)
65
+ assert isinstance(wav, torch.Tensor)
66
+ return wav
67
+ wav, meta = super().__getitem__(index)
68
+ return wav, AudioInfo(**meta.to_dict())
69
+
70
+
71
+ def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72
+ """Preprocess a single keyword or possible a list of keywords."""
73
+ if isinstance(value, list):
74
+ return get_keyword_list(value)
75
+ else:
76
+ return get_keyword(value)
77
+
78
+
79
+ def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80
+ """Preprocess a single keyword."""
81
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82
+ return None
83
+ else:
84
+ return value.strip()
85
+
86
+
87
+ def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88
+ """Preprocess a single keyword."""
89
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90
+ return None
91
+ else:
92
+ return value.strip().lower()
93
+
94
+
95
+ def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96
+ """Preprocess a list of keywords."""
97
+ if isinstance(values, str):
98
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
99
+ elif isinstance(values, float) and math.isnan(values):
100
+ values = []
101
+ if not isinstance(values, list):
102
+ logger.debug(f"Unexpected keyword list {values}")
103
+ values = [str(values)]
104
+
105
+ kws = [get_keyword(v) for v in values]
106
+ kw_list = [k for k in kws if k is not None]
107
+ if len(kw_list) == 0:
108
+ return None
109
+ else:
110
+ return kw_list
audiocraft/data/music_dataset.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of music tracks with rich metadata.
7
+ """
8
+ from dataclasses import dataclass, field, fields, replace
9
+ import gzip
10
+ import json
11
+ import logging
12
+ from pathlib import Path
13
+ import random
14
+ import typing as tp
15
+
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ AudioInfo,
21
+ get_keyword_list,
22
+ get_keyword,
23
+ get_string
24
+ )
25
+ from ..modules.conditioners import (
26
+ ConditioningAttributes,
27
+ JointEmbedCondition,
28
+ WavCondition,
29
+ )
30
+ from ..utils.utils import warn_once
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class MusicInfo(AudioInfo):
38
+ """Segment info augmented with music metadata.
39
+ """
40
+ # music-specific metadata
41
+ title: tp.Optional[str] = None
42
+ artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
43
+ key: tp.Optional[str] = None
44
+ bpm: tp.Optional[float] = None
45
+ genre: tp.Optional[str] = None
46
+ moods: tp.Optional[list] = None
47
+ keywords: tp.Optional[list] = None
48
+ description: tp.Optional[str] = None
49
+ name: tp.Optional[str] = None
50
+ instrument: tp.Optional[str] = None
51
+ # original wav accompanying the metadata
52
+ self_wav: tp.Optional[WavCondition] = None
53
+ # dict mapping attributes names to tuple of wav, text and metadata
54
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
55
+
56
+ @property
57
+ def has_music_meta(self) -> bool:
58
+ return self.name is not None
59
+
60
+ def to_condition_attributes(self) -> ConditioningAttributes:
61
+ out = ConditioningAttributes()
62
+ for _field in fields(self):
63
+ key, value = _field.name, getattr(self, _field.name)
64
+ if key == 'self_wav':
65
+ out.wav[key] = value
66
+ elif key == 'joint_embed':
67
+ for embed_attribute, embed_cond in value.items():
68
+ out.joint_embed[embed_attribute] = embed_cond
69
+ else:
70
+ if isinstance(value, list):
71
+ value = ' '.join(value)
72
+ out.text[key] = value
73
+ return out
74
+
75
+ @staticmethod
76
+ def attribute_getter(attribute):
77
+ if attribute == 'bpm':
78
+ preprocess_func = get_bpm
79
+ elif attribute == 'key':
80
+ preprocess_func = get_musical_key
81
+ elif attribute in ['moods', 'keywords']:
82
+ preprocess_func = get_keyword_list
83
+ elif attribute in ['genre', 'name', 'instrument']:
84
+ preprocess_func = get_keyword
85
+ elif attribute in ['title', 'artist', 'description']:
86
+ preprocess_func = get_string
87
+ else:
88
+ preprocess_func = None
89
+ return preprocess_func
90
+
91
+ @classmethod
92
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
93
+ _dictionary: tp.Dict[str, tp.Any] = {}
94
+
95
+ # allow a subset of attributes to not be loaded from the dictionary
96
+ # these attributes may be populated later
97
+ post_init_attributes = ['self_wav', 'joint_embed']
98
+ optional_fields = ['keywords']
99
+
100
+ for _field in fields(cls):
101
+ if _field.name in post_init_attributes:
102
+ continue
103
+ elif _field.name not in dictionary:
104
+ if fields_required and _field.name not in optional_fields:
105
+ raise KeyError(f"Unexpected missing key: {_field.name}")
106
+ else:
107
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
108
+ value = dictionary[_field.name]
109
+ if preprocess_func:
110
+ value = preprocess_func(value)
111
+ _dictionary[_field.name] = value
112
+ return cls(**_dictionary)
113
+
114
+
115
+ def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
116
+ drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
117
+ """Augment MusicInfo description with additional metadata fields and potential dropout.
118
+ Additional textual attributes are added given probability 'merge_text_conditions_p' and
119
+ the original textual description is dropped from the augmented description given probability drop_desc_p.
120
+
121
+ Args:
122
+ music_info (MusicInfo): The music metadata to augment.
123
+ merge_text_p (float): Probability of merging additional metadata to the description.
124
+ If provided value is 0, then no merging is performed.
125
+ drop_desc_p (float): Probability of dropping the original description on text merge.
126
+ if provided value is 0, then no drop out is performed.
127
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
128
+ Returns:
129
+ MusicInfo: The MusicInfo with augmented textual description.
130
+ """
131
+ def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
132
+ valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
133
+ valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
134
+ keep_field = random.uniform(0, 1) < drop_other_p
135
+ return valid_field_name and valid_field_value and keep_field
136
+
137
+ def process_value(v: tp.Any) -> str:
138
+ if isinstance(v, (int, float, str)):
139
+ return str(v)
140
+ if isinstance(v, list):
141
+ return ", ".join(v)
142
+ else:
143
+ raise ValueError(f"Unknown type for text value! ({type(v), v})")
144
+
145
+ description = music_info.description
146
+
147
+ metadata_text = ""
148
+ if random.uniform(0, 1) < merge_text_p:
149
+ meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
150
+ for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
151
+ random.shuffle(meta_pairs)
152
+ metadata_text = ". ".join(meta_pairs)
153
+ description = description if not random.uniform(0, 1) < drop_desc_p else None
154
+ logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
155
+
156
+ if description is None:
157
+ description = metadata_text if len(metadata_text) > 1 else None
158
+ else:
159
+ description = ". ".join([description.rstrip('.'), metadata_text])
160
+ description = description.strip() if description else None
161
+
162
+ music_info = replace(music_info)
163
+ music_info.description = description
164
+ return music_info
165
+
166
+
167
+ class Paraphraser:
168
+ def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
169
+ self.paraphrase_p = paraphrase_p
170
+ open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
171
+ with open_fn(paraphrase_source, 'rb') as f: # type: ignore
172
+ self.paraphrase_source = json.loads(f.read())
173
+ logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
174
+
175
+ def sample_paraphrase(self, audio_path: str, description: str):
176
+ if random.random() >= self.paraphrase_p:
177
+ return description
178
+ info_path = Path(audio_path).with_suffix('.json')
179
+ if info_path not in self.paraphrase_source:
180
+ warn_once(logger, f"{info_path} not in paraphrase source!")
181
+ return description
182
+ new_desc = random.choice(self.paraphrase_source[info_path])
183
+ logger.debug(f"{description} -> {new_desc}")
184
+ return new_desc
185
+
186
+
187
+ class MusicDataset(InfoAudioDataset):
188
+ """Music dataset is an AudioDataset with music-related metadata.
189
+
190
+ Args:
191
+ info_fields_required (bool): Whether to enforce having required fields.
192
+ merge_text_p (float): Probability of merging additional metadata to the description.
193
+ drop_desc_p (float): Probability of dropping the original description on text merge.
194
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
195
+ joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
196
+ paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
197
+ paraphrases for the description. The json should be a dict with keys are the
198
+ original info path (e.g. track_path.json) and each value is a list of possible
199
+ paraphrased.
200
+ paraphrase_p (float): probability of taking a paraphrase.
201
+
202
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
203
+ """
204
+ def __init__(self, *args, info_fields_required: bool = True,
205
+ merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
206
+ joint_embed_attributes: tp.List[str] = [],
207
+ paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
208
+ **kwargs):
209
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
210
+ super().__init__(*args, **kwargs)
211
+ self.info_fields_required = info_fields_required
212
+ self.merge_text_p = merge_text_p
213
+ self.drop_desc_p = drop_desc_p
214
+ self.drop_other_p = drop_other_p
215
+ self.joint_embed_attributes = joint_embed_attributes
216
+ self.paraphraser = None
217
+ if paraphrase_source is not None:
218
+ self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
219
+
220
+ def __getitem__(self, index):
221
+ wav, info = super().__getitem__(index)
222
+ info_data = info.to_dict()
223
+ music_info_path = Path(info.meta.path).with_suffix('.json')
224
+
225
+ if Path(music_info_path).exists():
226
+ with open(music_info_path, 'r') as json_file:
227
+ music_data = json.load(json_file)
228
+ music_data.update(info_data)
229
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
230
+ if self.paraphraser is not None:
231
+ music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
232
+ if self.merge_text_p:
233
+ music_info = augment_music_info_description(
234
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
235
+ else:
236
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
237
+
238
+ music_info.self_wav = WavCondition(
239
+ wav=wav[None], length=torch.tensor([info.n_frames]),
240
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
241
+
242
+ for att in self.joint_embed_attributes:
243
+ att_value = getattr(music_info, att)
244
+ joint_embed_cond = JointEmbedCondition(
245
+ wav[None], [att_value], torch.tensor([info.n_frames]),
246
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
247
+ music_info.joint_embed[att] = joint_embed_cond
248
+
249
+ return wav, music_info
250
+
251
+
252
+ def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
253
+ """Preprocess key keywords, discarding them if there are multiple key defined."""
254
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
255
+ return None
256
+ elif ',' in value:
257
+ # For now, we discard when multiple keys are defined separated with comas
258
+ return None
259
+ else:
260
+ return value.strip().lower()
261
+
262
+
263
+ def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
264
+ """Preprocess to a float."""
265
+ if value is None:
266
+ return None
267
+ try:
268
+ return float(value)
269
+ except ValueError:
270
+ return None
audiocraft/data/sound_dataset.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of audio with a simple description.
7
+ """
8
+
9
+ from dataclasses import dataclass, fields, replace
10
+ import json
11
+ from pathlib import Path
12
+ import random
13
+ import typing as tp
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ get_keyword_or_keyword_list
21
+ )
22
+ from ..modules.conditioners import (
23
+ ConditioningAttributes,
24
+ SegmentWithAttributes,
25
+ WavCondition,
26
+ )
27
+
28
+
29
+ EPS = torch.finfo(torch.float32).eps
30
+ TARGET_LEVEL_LOWER = -35
31
+ TARGET_LEVEL_UPPER = -15
32
+
33
+
34
+ @dataclass
35
+ class SoundInfo(SegmentWithAttributes):
36
+ """Segment info augmented with Sound metadata.
37
+ """
38
+ description: tp.Optional[str] = None
39
+ self_wav: tp.Optional[torch.Tensor] = None
40
+
41
+ @property
42
+ def has_sound_meta(self) -> bool:
43
+ return self.description is not None
44
+
45
+ def to_condition_attributes(self) -> ConditioningAttributes:
46
+ out = ConditioningAttributes()
47
+
48
+ for _field in fields(self):
49
+ key, value = _field.name, getattr(self, _field.name)
50
+ if key == 'self_wav':
51
+ out.wav[key] = value
52
+ else:
53
+ out.text[key] = value
54
+ return out
55
+
56
+ @staticmethod
57
+ def attribute_getter(attribute):
58
+ if attribute == 'description':
59
+ preprocess_func = get_keyword_or_keyword_list
60
+ else:
61
+ preprocess_func = None
62
+ return preprocess_func
63
+
64
+ @classmethod
65
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
66
+ _dictionary: tp.Dict[str, tp.Any] = {}
67
+
68
+ # allow a subset of attributes to not be loaded from the dictionary
69
+ # these attributes may be populated later
70
+ post_init_attributes = ['self_wav']
71
+
72
+ for _field in fields(cls):
73
+ if _field.name in post_init_attributes:
74
+ continue
75
+ elif _field.name not in dictionary:
76
+ if fields_required:
77
+ raise KeyError(f"Unexpected missing key: {_field.name}")
78
+ else:
79
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
80
+ value = dictionary[_field.name]
81
+ if preprocess_func:
82
+ value = preprocess_func(value)
83
+ _dictionary[_field.name] = value
84
+ return cls(**_dictionary)
85
+
86
+
87
+ class SoundDataset(InfoAudioDataset):
88
+ """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
89
+
90
+ Args:
91
+ info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
92
+ external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
93
+ The metadata files contained in this folder are expected to match the stem of the audio file with
94
+ a json extension.
95
+ aug_p (float): Probability of performing audio mixing augmentation on the batch.
96
+ mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
97
+ mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
98
+ mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
99
+ mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
100
+ kwargs: Additional arguments for AudioDataset.
101
+
102
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
103
+ """
104
+ def __init__(
105
+ self,
106
+ *args,
107
+ info_fields_required: bool = True,
108
+ external_metadata_source: tp.Optional[str] = None,
109
+ aug_p: float = 0.,
110
+ mix_p: float = 0.,
111
+ mix_snr_low: int = -5,
112
+ mix_snr_high: int = 5,
113
+ mix_min_overlap: float = 0.5,
114
+ **kwargs
115
+ ):
116
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
117
+ super().__init__(*args, **kwargs)
118
+ self.info_fields_required = info_fields_required
119
+ self.external_metadata_source = external_metadata_source
120
+ self.aug_p = aug_p
121
+ self.mix_p = mix_p
122
+ if self.aug_p > 0:
123
+ assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
124
+ assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
125
+ self.mix_snr_low = mix_snr_low
126
+ self.mix_snr_high = mix_snr_high
127
+ self.mix_min_overlap = mix_min_overlap
128
+
129
+ def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
130
+ """Get path of JSON with metadata (description, etc.).
131
+ If there exists a JSON with the same name as 'path.name', then it will be used.
132
+ Else, such JSON will be searched for in an external json source folder if it exists.
133
+ """
134
+ info_path = Path(path).with_suffix('.json')
135
+ if Path(info_path).exists():
136
+ return info_path
137
+ elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
138
+ return Path(self.external_metadata_source) / info_path.name
139
+ else:
140
+ raise Exception(f"Unable to find a metadata JSON for path: {path}")
141
+
142
+ def __getitem__(self, index):
143
+ wav, info = super().__getitem__(index)
144
+ info_data = info.to_dict()
145
+ info_path = self._get_info_path(info.meta.path)
146
+ if Path(info_path).exists():
147
+ with open(info_path, 'r') as json_file:
148
+ sound_data = json.load(json_file)
149
+ sound_data.update(info_data)
150
+ sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
151
+ # if there are multiple descriptions, sample one randomly
152
+ if isinstance(sound_info.description, list):
153
+ sound_info.description = random.choice(sound_info.description)
154
+ else:
155
+ sound_info = SoundInfo.from_dict(info_data, fields_required=False)
156
+
157
+ sound_info.self_wav = WavCondition(
158
+ wav=wav[None], length=torch.tensor([info.n_frames]),
159
+ sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
160
+
161
+ return wav, sound_info
162
+
163
+ def collater(self, samples):
164
+ # when training, audio mixing is performed in the collate function
165
+ wav, sound_info = super().collater(samples) # SoundDataset always returns infos
166
+ if self.aug_p > 0:
167
+ wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
168
+ snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
169
+ min_overlap=self.mix_min_overlap)
170
+ return wav, sound_info
171
+
172
+
173
+ def rms_f(x: torch.Tensor) -> torch.Tensor:
174
+ return (x ** 2).mean(1).pow(0.5)
175
+
176
+
177
+ def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
178
+ """Normalize the signal to the target level."""
179
+ rms = rms_f(audio)
180
+ scalar = 10 ** (target_level / 20) / (rms + EPS)
181
+ audio = audio * scalar.unsqueeze(1)
182
+ return audio
183
+
184
+
185
+ def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
186
+ return (abs(audio) > clipping_threshold).any(1)
187
+
188
+
189
+ def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
190
+ start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
191
+ remainder = src.shape[1] - start
192
+ if dst.shape[1] > remainder:
193
+ src[:, start:] = src[:, start:] + dst[:, :remainder]
194
+ else:
195
+ src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
196
+ return src
197
+
198
+
199
+ def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
200
+ target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
201
+ """Function to mix clean speech and noise at various SNR levels.
202
+
203
+ Args:
204
+ clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
205
+ noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
206
+ snr (int): SNR level when mixing.
207
+ min_overlap (float): Minimum overlap between the two mixed sources.
208
+ target_level (int): Gain level in dB.
209
+ clipping_threshold (float): Threshold for clipping the audio.
210
+ Returns:
211
+ torch.Tensor: The mixed audio, of shape [B, T].
212
+ """
213
+ if clean.shape[1] > noise.shape[1]:
214
+ noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
215
+ else:
216
+ noise = noise[:, :clean.shape[1]]
217
+
218
+ # normalizing to -25 dB FS
219
+ clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
220
+ clean = normalize(clean, target_level)
221
+ rmsclean = rms_f(clean)
222
+
223
+ noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
224
+ noise = normalize(noise, target_level)
225
+ rmsnoise = rms_f(noise)
226
+
227
+ # set the noise level for a given SNR
228
+ noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
229
+ noisenewlevel = noise * noisescalar
230
+
231
+ # mix noise and clean speech
232
+ noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
233
+
234
+ # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
235
+ # there is a chance of clipping that might happen with very less probability, which is not a major issue.
236
+ noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
237
+ rmsnoisy = rms_f(noisyspeech)
238
+ scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
239
+ noisyspeech = noisyspeech * scalarnoisy
240
+ clean = clean * scalarnoisy
241
+ noisenewlevel = noisenewlevel * scalarnoisy
242
+
243
+ # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
244
+ clipped = is_clipped(noisyspeech)
245
+ if clipped.any():
246
+ noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
247
+ noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
248
+
249
+ return noisyspeech
250
+
251
+
252
+ def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
253
+ if snr_low == snr_high:
254
+ snr = snr_low
255
+ else:
256
+ snr = np.random.randint(snr_low, snr_high)
257
+ mix = snr_mixer(src, dst, snr, min_overlap)
258
+ return mix
259
+
260
+
261
+ def mix_text(src_text: str, dst_text: str):
262
+ """Mix text from different sources by concatenating them."""
263
+ if src_text == dst_text:
264
+ return src_text
265
+ return src_text + " " + dst_text
266
+
267
+
268
+ def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
269
+ snr_low: int, snr_high: int, min_overlap: float):
270
+ """Mix samples within a batch, summing the waveforms and concatenating the text infos.
271
+
272
+ Args:
273
+ wavs (torch.Tensor): Audio tensors of shape [B, C, T].
274
+ infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
275
+ aug_p (float): Augmentation probability.
276
+ mix_p (float): Proportion of items in the batch to mix (and merge) together.
277
+ snr_low (int): Lowerbound for sampling SNR.
278
+ snr_high (int): Upperbound for sampling SNR.
279
+ min_overlap (float): Minimum overlap between mixed samples.
280
+ Returns:
281
+ tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
282
+ and mixed SoundInfo for the given batch.
283
+ """
284
+ # no mixing to perform within the batch
285
+ if mix_p == 0:
286
+ return wavs, infos
287
+
288
+ if random.uniform(0, 1) < aug_p:
289
+ # perform all augmentations on waveforms as [B, T]
290
+ # randomly picking pairs of audio to mix
291
+ assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
292
+ wavs = wavs.mean(dim=1, keepdim=False)
293
+ B, T = wavs.shape
294
+ k = int(mix_p * B)
295
+ mixed_sources_idx = torch.randperm(B)[:k]
296
+ mixed_targets_idx = torch.randperm(B)[:k]
297
+ aug_wavs = snr_mix(
298
+ wavs[mixed_sources_idx],
299
+ wavs[mixed_targets_idx],
300
+ snr_low,
301
+ snr_high,
302
+ min_overlap,
303
+ )
304
+ # mixing textual descriptions in metadata
305
+ descriptions = [info.description for info in infos]
306
+ aug_infos = []
307
+ for i, j in zip(mixed_sources_idx, mixed_targets_idx):
308
+ text = mix_text(descriptions[i], descriptions[j])
309
+ m = replace(infos[i])
310
+ m.description = text
311
+ aug_infos.append(m)
312
+
313
+ # back to [B, C, T]
314
+ aug_wavs = aug_wavs.unsqueeze(1)
315
+ assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
316
+ assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
317
+ assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
318
+
319
+ return aug_wavs, aug_infos # [B, C, T]
320
+ else:
321
+ # randomly pick samples in the batch to match
322
+ # the batch size when performing audio mixing
323
+ B, C, T = wavs.shape
324
+ k = int(mix_p * B)
325
+ wav_idx = torch.randperm(B)[:k]
326
+ wavs = wavs[wav_idx]
327
+ infos = [infos[i] for i in wav_idx]
328
+ assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
329
+
330
+ return wavs, infos # [B, C, T]
audiocraft/data/zip.py CHANGED
@@ -3,6 +3,8 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
 
 
6
 
7
  import typing
8
  import zipfile
@@ -18,13 +20,13 @@ MODE = Literal['r', 'w', 'x', 'a']
18
 
19
  @dataclass(order=True)
20
  class PathInZip:
21
- """Class for holding a path of file within a zip file.
22
 
23
  Args:
24
- path: The convention is <path_to_zip>:<relative_path_inside_zip>
25
  Let's assume there is a zip file /some/location/foo.zip
26
  and inside of it is a json file located at /data/file1.json,
27
- Then we expect path = "/some/location/foo.zip:/data/file1.json"
28
  """
29
 
30
  INFO_PATH_SEP = ':'
@@ -55,7 +57,7 @@ def set_zip_cache_size(max_size: int):
55
  """Sets the maximal LRU caching for zip file opening.
56
 
57
  Args:
58
- max_size: the maximal LRU cache.
59
  """
60
  global _cached_open_zip
61
  _cached_open_zip = lru_cache(max_size)(_open_zip)
@@ -65,8 +67,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
65
  """Opens a file stored inside a zip and returns a file-like object.
66
 
67
  Args:
68
- path_in_zip: A PathInZip object representing the file to return a file-like object of.
69
- mode: The mode in which to open the file with.
70
  Returns:
71
  A file-like object for PathInZip.
72
  """
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """Utility for reading some info from inside a zip file.
7
+ """
8
 
9
  import typing
10
  import zipfile
 
20
 
21
  @dataclass(order=True)
22
  class PathInZip:
23
+ """Hold a path of file within a zip file.
24
 
25
  Args:
26
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
  Let's assume there is a zip file /some/location/foo.zip
28
  and inside of it is a json file located at /data/file1.json,
29
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
  """
31
 
32
  INFO_PATH_SEP = ':'
 
57
  """Sets the maximal LRU caching for zip file opening.
58
 
59
  Args:
60
+ max_size (int): the maximal LRU cache.
61
  """
62
  global _cached_open_zip
63
  _cached_open_zip = lru_cache(max_size)(_open_zip)
 
67
  """Opens a file stored inside a zip and returns a file-like object.
68
 
69
  Args:
70
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
+ mode (str): The mode in which to open the file with.
72
  Returns:
73
  A file-like object for PathInZip.
74
  """
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
+ """
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ import re
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+
19
+ from .utils.cluster import _guess_cluster_type
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AudioCraftEnvironment:
26
+ """Environment configuration for teams and clusters.
27
+
28
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
+ or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
30
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
+ allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
32
+ map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
33
+
34
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
+ Use the following environment variables to specify the cluster, team or configuration:
36
+
37
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
+ cannot be inferred automatically.
39
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
+ If not set, configuration is read from config/teams.yaml.
41
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
+ Cluster configuration are shared across teams to match compute allocation,
43
+ specify your cluster configuration in the configuration file under a key mapping
44
+ your team name.
45
+ """
46
+ _instance = None
47
+ DEFAULT_TEAM = "default"
48
+
49
+ def __init__(self) -> None:
50
+ """Loads configuration."""
51
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
+ cluster_type = _guess_cluster_type()
53
+ cluster = os.getenv(
54
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
55
+ )
56
+ logger.info("Detecting cluster type %s", cluster_type)
57
+
58
+ self.cluster: str = cluster
59
+
60
+ config_path = os.getenv(
61
+ "AUDIOCRAFT_CONFIG",
62
+ Path(__file__)
63
+ .parent.parent.joinpath("config/teams", self.team)
64
+ .with_suffix(".yaml"),
65
+ )
66
+ self.config = omegaconf.OmegaConf.load(config_path)
67
+ self._dataset_mappers = []
68
+ cluster_config = self._get_cluster_config()
69
+ if "dataset_mappers" in cluster_config:
70
+ for pattern, repl in cluster_config["dataset_mappers"].items():
71
+ regex = re.compile(pattern)
72
+ self._dataset_mappers.append((regex, repl))
73
+
74
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
75
+ assert isinstance(self.config, omegaconf.DictConfig)
76
+ return self.config[self.cluster]
77
+
78
+ @classmethod
79
+ def instance(cls):
80
+ if cls._instance is None:
81
+ cls._instance = cls()
82
+ return cls._instance
83
+
84
+ @classmethod
85
+ def reset(cls):
86
+ """Clears the environment and forces a reload on next invocation."""
87
+ cls._instance = None
88
+
89
+ @classmethod
90
+ def get_team(cls) -> str:
91
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
+ If not defined, defaults to "labs".
93
+ """
94
+ return cls.instance().team
95
+
96
+ @classmethod
97
+ def get_cluster(cls) -> str:
98
+ """Gets the detected cluster.
99
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
+ """
101
+ return cls.instance().cluster
102
+
103
+ @classmethod
104
+ def get_dora_dir(cls) -> Path:
105
+ """Gets the path to the dora directory for the current team and cluster.
106
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
+ """
108
+ cluster_config = cls.instance()._get_cluster_config()
109
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
+ logger.warning(f"Dora directory: {dora_dir}")
111
+ return Path(dora_dir)
112
+
113
+ @classmethod
114
+ def get_reference_dir(cls) -> Path:
115
+ """Gets the path to the reference directory for the current team and cluster.
116
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
+ """
118
+ cluster_config = cls.instance()._get_cluster_config()
119
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
+
121
+ @classmethod
122
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
123
+ """Get the list of nodes to exclude for that cluster."""
124
+ cluster_config = cls.instance()._get_cluster_config()
125
+ return cluster_config.get("slurm_exclude")
126
+
127
+ @classmethod
128
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
+
131
+ Args:
132
+ partition_types (list[str], optional): partition types to retrieve. Values must be
133
+ from ['global', 'team']. If not provided, the global partition is returned.
134
+ """
135
+ if not partition_types:
136
+ partition_types = ["global"]
137
+
138
+ cluster_config = cls.instance()._get_cluster_config()
139
+ partitions = [
140
+ cluster_config["partitions"][partition_type]
141
+ for partition_type in partition_types
142
+ ]
143
+ return ",".join(partitions)
144
+
145
+ @classmethod
146
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
148
+
149
+ Args:
150
+ path (str or Path): Path to resolve.
151
+ Returns:
152
+ Path: Resolved path.
153
+ """
154
+ path = str(path)
155
+
156
+ if path.startswith("//reference"):
157
+ reference_dir = cls.get_reference_dir()
158
+ logger.warn(f"Reference directory: {reference_dir}")
159
+ assert (
160
+ reference_dir.exists() and reference_dir.is_dir()
161
+ ), f"Reference directory does not exist: {reference_dir}."
162
+ path = re.sub("^//reference", str(reference_dir), path)
163
+
164
+ return Path(path)
165
+
166
+ @classmethod
167
+ def apply_dataset_mappers(cls, path: str) -> str:
168
+ """Applies dataset mapping regex rules as defined in the configuration.
169
+ If no rules are defined, the path is returned as-is.
170
+ """
171
+ instance = cls.instance()
172
+
173
+ for pattern, repl in instance._dataset_mappers:
174
+ path = pattern.sub(repl, path)
175
+
176
+ return path
audiocraft/grids/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dora Grids."""
audiocraft/grids/_base_explorers.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import time
9
+ import typing as tp
10
+ from dora import Explorer
11
+ import treetable as tt
12
+
13
+
14
+ def get_sheep_ping(sheep) -> tp.Optional[str]:
15
+ """Return the amount of time since the Sheep made some update
16
+ to its log. Returns a str using the relevant time unit."""
17
+ ping = None
18
+ if sheep.log is not None and sheep.log.exists():
19
+ delta = time.time() - sheep.log.stat().st_mtime
20
+ if delta > 3600 * 24:
21
+ ping = f'{delta / (3600 * 24):.1f}d'
22
+ elif delta > 3600:
23
+ ping = f'{delta / (3600):.1f}h'
24
+ elif delta > 60:
25
+ ping = f'{delta / 60:.1f}m'
26
+ else:
27
+ ping = f'{delta:.1f}s'
28
+ return ping
29
+
30
+
31
+ class BaseExplorer(ABC, Explorer):
32
+ """Base explorer for AudioCraft grids.
33
+
34
+ All task specific solvers are expected to implement the `get_grid_metrics`
35
+ method to specify logic about metrics to display for a given task.
36
+
37
+ If additional stages are used, the child explorer must define how to handle
38
+ these new stages in the `process_history` and `process_sheep` methods.
39
+ """
40
+ def stages(self):
41
+ return ["train", "valid", "evaluate"]
42
+
43
+ def get_grid_meta(self):
44
+ """Returns the list of Meta information to display for each XP/job.
45
+ """
46
+ return [
47
+ tt.leaf("index", align=">"),
48
+ tt.leaf("name", wrap=140),
49
+ tt.leaf("state"),
50
+ tt.leaf("sig", align=">"),
51
+ tt.leaf("sid", align="<"),
52
+ ]
53
+
54
+ @abstractmethod
55
+ def get_grid_metrics(self):
56
+ """Return the metrics that should be displayed in the tracking table.
57
+ """
58
+ ...
59
+
60
+ def process_sheep(self, sheep, history):
61
+ train = {
62
+ "epoch": len(history),
63
+ }
64
+ parts = {"train": train}
65
+ for metrics in history:
66
+ for key, sub in metrics.items():
67
+ part = parts.get(key, {})
68
+ if 'duration' in sub:
69
+ # Convert to minutes for readability.
70
+ sub['duration'] = sub['duration'] / 60.
71
+ part.update(sub)
72
+ parts[key] = part
73
+ ping = get_sheep_ping(sheep)
74
+ if ping is not None:
75
+ for name in self.stages():
76
+ if name not in parts:
77
+ parts[name] = {}
78
+ # Add the ping to each part for convenience.
79
+ parts[name]['ping'] = ping
80
+ return parts
audiocraft/grids/audiogen/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """AudioGen grids."""
audiocraft/grids/audiogen/audiogen_base_16khz.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ..musicgen._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=64, partition=partitions)
15
+ launcher.bind_(solver='audiogen/audiogen_base_16khz')
16
+ # replace this by the desired environmental sound dataset
17
+ launcher.bind_(dset='internal/sounds_16khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+
22
+ launcher.bind_(fsdp)
23
+ launcher(medium)
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluation with objective metrics for the pretrained AudioGen models.
9
+ This grid takes signature from the training grid and runs evaluation-only stage.
10
+
11
+ When running the grid for the first time, please use:
12
+ REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
13
+ and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14
+
15
+ Note that you need the proper metrics external libraries setup to use all
16
+ the objective metrics activated in this grid. Refer to the README for more information.
17
+ """
18
+
19
+ import os
20
+
21
+ from ..musicgen._explorers import GenerationEvalExplorer
22
+ from ...environment import AudioCraftEnvironment
23
+ from ... import train
24
+
25
+
26
+ def eval(launcher, batch_size: int = 32):
27
+ opts = {
28
+ 'dset': 'audio/audiocaps_16khz',
29
+ 'solver/audiogen/evaluation': 'objective_eval',
30
+ 'execute_only': 'evaluate',
31
+ '+dataset.evaluate.batch_size': batch_size,
32
+ '+metrics.fad.tf.batch_size': 32,
33
+ }
34
+ # binary for FAD computation: replace this path with your own path
35
+ metrics_opts = {
36
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
37
+ }
38
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
39
+ opt2 = {'transformer_lm.two_step_cfg': True}
40
+
41
+ sub = launcher.bind(opts)
42
+ sub.bind_(metrics_opts)
43
+
44
+ # base objective metrics
45
+ sub(opt1, opt2)
46
+
47
+
48
+ @GenerationEvalExplorer
49
+ def explorer(launcher):
50
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
51
+ launcher.slurm_(gpus=4, partition=partitions)
52
+
53
+ if 'REGEN' not in os.environ:
54
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
55
+ with launcher.job_array():
56
+ for sig in folder.iterdir():
57
+ if not sig.is_symlink():
58
+ continue
59
+ xp = train.main.get_xp_from_sig(sig.name)
60
+ launcher(xp.argv)
61
+ return
62
+
63
+ audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
64
+ audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
65
+
66
+ audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
67
+ audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
68
+ eval(audiogen_base_medium, batch_size=128)
audiocraft/grids/compression/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """EnCodec grids."""
audiocraft/grids/compression/_explorers.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
+ class CompressionExplorer(BaseExplorer):
13
+ eval_metrics = ["sisnr", "visqol"]
14
+
15
+ def stages(self):
16
+ return ["train", "valid", "evaluate"]
17
+
18
+ def get_grid_meta(self):
19
+ """Returns the list of Meta information to display for each XP/job.
20
+ """
21
+ return [
22
+ tt.leaf("index", align=">"),
23
+ tt.leaf("name", wrap=140),
24
+ tt.leaf("state"),
25
+ tt.leaf("sig", align=">"),
26
+ ]
27
+
28
+ def get_grid_metrics(self):
29
+ """Return the metrics that should be displayed in the tracking table.
30
+ """
31
+ return [
32
+ tt.group(
33
+ "train",
34
+ [
35
+ tt.leaf("epoch"),
36
+ tt.leaf("bandwidth", ".2f"),
37
+ tt.leaf("adv", ".4f"),
38
+ tt.leaf("d_loss", ".4f"),
39
+ ],
40
+ align=">",
41
+ ),
42
+ tt.group(
43
+ "valid",
44
+ [
45
+ tt.leaf("bandwidth", ".2f"),
46
+ tt.leaf("adv", ".4f"),
47
+ tt.leaf("msspec", ".4f"),
48
+ tt.leaf("sisnr", ".2f"),
49
+ ],
50
+ align=">",
51
+ ),
52
+ tt.group(
53
+ "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
54
+ ),
55
+ ]
audiocraft/grids/compression/debug.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid is a minimal example for debugging compression task
13
+ and how to override parameters directly in a grid.
14
+ Learn more about dora grids: https://github.com/facebookresearch/dora
15
+ """
16
+
17
+ from ._explorers import CompressionExplorer
18
+ from ...environment import AudioCraftEnvironment
19
+
20
+
21
+ @CompressionExplorer
22
+ def explorer(launcher):
23
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
24
+ launcher.slurm_(gpus=2, partition=partitions)
25
+ launcher.bind_(solver='compression/debug')
26
+
27
+ with launcher.job_array():
28
+ # base debug task using config from solver=compression/debug
29
+ launcher()
30
+ # we can override parameters in the grid to launch additional xps
31
+ launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
audiocraft/grids/compression/encodec_audiogen_16khz.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
24
+ # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
25
+ launcher.bind_(solver='compression/encodec_audiogen_16khz')
26
+ # replace this by the desired sound dataset
27
+ launcher.bind_(dset='internal/sounds_16khz')
28
+ # launch xp
29
+ launcher()
audiocraft/grids/compression/encodec_base_24khz.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid shows how to train a base causal EnCodec model at 24 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # base causal EnCodec trained on monophonic audio sampled at 24 kHz
24
+ launcher.bind_(solver='compression/encodec_base_24khz')
25
+ # replace this by the desired dataset
26
+ launcher.bind_(dset='audio/example')
27
+ # launch xp
28
+ launcher()
audiocraft/grids/compression/encodec_musicgen_32khz.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel and experiment by commenting its line.
11
+
12
+ This grid shows how to train a MusicGen EnCodec model at 32 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
24
+ # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
25
+ launcher.bind_(solver='compression/encodec_musicgen_32khz')
26
+ # replace this by the desired music dataset
27
+ launcher.bind_(dset='internal/music_400k_32khz')
28
+ # launch xp
29
+ launcher()
30
+ launcher({
31
+ 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
32
+ 'label': 'visqol',
33
+ 'evaluate.metrics.visqol': True
34
+ })
audiocraft/grids/diffusion/4_bands_base_32khz.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Training of the 4 diffusion models described in
9
+ "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10
+ (paper link).
11
+ """
12
+
13
+ from ._explorers import DiffusionExplorer
14
+
15
+
16
+ @DiffusionExplorer
17
+ def explorer(launcher):
18
+ launcher.slurm_(gpus=4, partition='learnfair')
19
+
20
+ launcher.bind_({'solver': 'diffusion/default',
21
+ 'dset': 'internal/music_10k_32khz'})
22
+
23
+ with launcher.job_array():
24
+ launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
25
+ launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
26
+ launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
27
+ launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
audiocraft/grids/diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Diffusion grids."""
audiocraft/grids/diffusion/_explorers.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
+ class DiffusionExplorer(BaseExplorer):
13
+ eval_metrics = ["sisnr", "visqol"]
14
+
15
+ def stages(self):
16
+ return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
17
+
18
+ def get_grid_meta(self):
19
+ """Returns the list of Meta information to display for each XP/job.
20
+ """
21
+ return [
22
+ tt.leaf("index", align=">"),
23
+ tt.leaf("name", wrap=140),
24
+ tt.leaf("state"),
25
+ tt.leaf("sig", align=">"),
26
+ ]
27
+
28
+ def get_grid_metrics(self):
29
+ """Return the metrics that should be displayed in the tracking table.
30
+ """
31
+ return [
32
+ tt.group(
33
+ "train",
34
+ [
35
+ tt.leaf("epoch"),
36
+ tt.leaf("loss", ".3%"),
37
+ ],
38
+ align=">",
39
+ ),
40
+ tt.group(
41
+ "valid",
42
+ [
43
+ tt.leaf("loss", ".3%"),
44
+ # tt.leaf("loss_0", ".3%"),
45
+ ],
46
+ align=">",
47
+ ),
48
+ tt.group(
49
+ "valid_ema",
50
+ [
51
+ tt.leaf("loss", ".3%"),
52
+ # tt.leaf("loss_0", ".3%"),
53
+ ],
54
+ align=">",
55
+ ),
56
+ tt.group(
57
+ "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
58
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
59
+ tt.leaf("rvm_3", ".4f"), ], align=">"
60
+ ),
61
+ tt.group(
62
+ "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
63
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
64
+ tt.leaf("rvm_3", ".4f")], align=">"
65
+ ),
66
+ ]
audiocraft/grids/musicgen/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """MusicGen grids."""
audiocraft/grids/musicgen/_explorers.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import treetable as tt
10
+
11
+ from .._base_explorers import BaseExplorer
12
+
13
+
14
+ class LMExplorer(BaseExplorer):
15
+ eval_metrics: tp.List[str] = []
16
+
17
+ def stages(self) -> tp.List[str]:
18
+ return ['train', 'valid']
19
+
20
+ def get_grid_metrics(self):
21
+ """Return the metrics that should be displayed in the tracking table."""
22
+ return [
23
+ tt.group(
24
+ 'train',
25
+ [
26
+ tt.leaf('epoch'),
27
+ tt.leaf('duration', '.1f'), # duration in minutes
28
+ tt.leaf('ping'),
29
+ tt.leaf('ce', '.4f'), # cross entropy
30
+ tt.leaf("ppl", '.3f'), # perplexity
31
+ ],
32
+ align='>',
33
+ ),
34
+ tt.group(
35
+ 'valid',
36
+ [
37
+ tt.leaf('ce', '.4f'),
38
+ tt.leaf('ppl', '.3f'),
39
+ tt.leaf('best_ppl', '.3f'),
40
+ ],
41
+ align='>',
42
+ ),
43
+ ]
44
+
45
+ def process_sheep(self, sheep, history):
46
+ parts = super().process_sheep(sheep, history)
47
+
48
+ track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
49
+ best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
50
+
51
+ def comparator(mode, a, b):
52
+ return a < b if mode == 'lower' else a > b
53
+
54
+ for metrics in history:
55
+ for key, sub in metrics.items():
56
+ for metric in track_by:
57
+ # for the validation set, keep track of best metrics (ppl in this example)
58
+ # this is so we can conveniently compare metrics between runs in the grid
59
+ if key == 'valid' and metric in sub and comparator(
60
+ track_by[metric], sub[metric], best_metrics[metric]
61
+ ):
62
+ best_metrics[metric] = sub[metric]
63
+
64
+ if 'valid' in parts:
65
+ parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
66
+ return parts
67
+
68
+
69
+ class GenerationEvalExplorer(BaseExplorer):
70
+ eval_metrics: tp.List[str] = []
71
+
72
+ def stages(self) -> tp.List[str]:
73
+ return ['evaluate']
74
+
75
+ def get_grid_metrics(self):
76
+ """Return the metrics that should be displayed in the tracking table."""
77
+ return [
78
+ tt.group(
79
+ 'evaluate',
80
+ [
81
+ tt.leaf('epoch', '.3f'),
82
+ tt.leaf('duration', '.1f'),
83
+ tt.leaf('ping'),
84
+ tt.leaf('ce', '.4f'),
85
+ tt.leaf('ppl', '.3f'),
86
+ tt.leaf('fad', '.3f'),
87
+ tt.leaf('kld', '.3f'),
88
+ tt.leaf('text_consistency', '.3f'),
89
+ tt.leaf('chroma_cosine', '.3f'),
90
+ ],
91
+ align='>',
92
+ ),
93
+ ]
audiocraft/grids/musicgen/musicgen_base_32khz.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ launcher.bind_(fsdp)
29
+
30
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
31
+ with launcher.job_array():
32
+ sub = launcher.bind()
33
+ sub()
34
+
35
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
36
+ with launcher.job_array():
37
+ sub = launcher.bind()
38
+ sub(medium, adam)
39
+
40
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
41
+ with launcher.job_array():
42
+ sub = launcher.bind()
43
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ # BEGINNING OF CACHE WRITING JOBS.
29
+ cache_write = {
30
+ 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
31
+ 'cache.write': True,
32
+ 'generate.every': 500,
33
+ 'evaluate.every': 500,
34
+ 'logging.log_updates': 50,
35
+ }
36
+
37
+ cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
38
+ cache_sub.bind_({'deadlock.use': True})
39
+ cache_sub.slurm_(gpus=8)
40
+ with launcher.job_array():
41
+ num_shards = 10 # total number of jobs running in parallel.
42
+ for shard in range(0, num_shards):
43
+ launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
44
+
45
+ # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
46
+ # OR SUFFICIENTLY AHEAD.
47
+ return
48
+
49
+ cache = {
50
+ 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
51
+ }
52
+ launcher.bind_(fsdp, cache)
53
+
54
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
55
+ with launcher.job_array():
56
+ sub = launcher.bind()
57
+ sub()
58
+
59
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
60
+ with launcher.job_array():
61
+ sub = launcher.bind()
62
+ sub(medium, adam)
63
+
64
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
65
+ with launcher.job_array():
66
+ sub = launcher.bind()
67
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
audiocraft/grids/musicgen/musicgen_clapemb_32khz.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+ launcher.bind_(conditioner='clapemb2music')
19
+
20
+ fsdp = {'autocast': False, 'fsdp.use': True}
21
+ cache_path = {'conditioners.description.clap.cache_path':
22
+ '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
23
+ text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}
24
+
25
+ launcher.bind_(fsdp)
26
+
27
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
28
+ with launcher.job_array():
29
+ launcher()
30
+ launcher(text_wav_training_opt)
31
+ launcher(cache_path)
32
+ launcher(cache_path, text_wav_training_opt)
audiocraft/grids/musicgen/musicgen_melody_32khz.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_melody_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
29
+ '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}
30
+
31
+ # CACHE GENERATION JOBS
32
+ n_cache_gen_jobs = 4
33
+ gen_sub = launcher.slurm(gpus=1)
34
+ gen_sub.bind_(
35
+ cache_path, {
36
+ # the cache is always computed over the whole file, so duration doesn't matter here.
37
+ 'dataset.segment_duration': 2.,
38
+ 'dataset.batch_size': 8,
39
+ 'dataset.train.permutation_on_files': True, # try to not repeat files.
40
+ 'optim.epochs': 10,
41
+ 'model/lm/model_scale': 'xsmall',
42
+
43
+ })
44
+ with gen_sub.job_array():
45
+ for gen_job in range(n_cache_gen_jobs):
46
+ gen_sub({'dataset.train.shuffle_seed': gen_job})
47
+
48
+ # ACTUAL TRAINING JOBS.
49
+ launcher.bind_(fsdp)
50
+
51
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
52
+ with launcher.job_array():
53
+ sub = launcher.bind()
54
+ sub()
55
+ sub(cache_path)
56
+
57
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
58
+ with launcher.job_array():
59
+ sub = launcher.bind()
60
+ sub(medium, adam)
61
+
62
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
63
+ with launcher.job_array():
64
+ sub = launcher.bind()
65
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluation with objective metrics for the pretrained MusicGen models.
9
+ This grid takes signature from the training grid and runs evaluation-only stage.
10
+
11
+ When running the grid for the first time, please use:
12
+ REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
13
+ and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14
+
15
+ Note that you need the proper metrics external libraries setup to use all
16
+ the objective metrics activated in this grid. Refer to the README for more information.
17
+ """
18
+
19
+ import os
20
+
21
+ from ._explorers import GenerationEvalExplorer
22
+ from ...environment import AudioCraftEnvironment
23
+ from ... import train
24
+
25
+
26
+ def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
27
+ opts = {
28
+ 'dset': 'audio/musiccaps_32khz',
29
+ 'solver/musicgen/evaluation': 'objective_eval',
30
+ 'execute_only': 'evaluate',
31
+ '+dataset.evaluate.batch_size': batch_size,
32
+ '+metrics.fad.tf.batch_size': 16,
33
+ }
34
+ # chroma-specific evaluation
35
+ chroma_opts = {
36
+ 'dset': 'internal/music_400k_32khz',
37
+ 'dataset.evaluate.segment_duration': 30,
38
+ 'dataset.evaluate.num_samples': 1000,
39
+ 'evaluate.metrics.chroma_cosine': True,
40
+ 'evaluate.metrics.fad': False,
41
+ 'evaluate.metrics.kld': False,
42
+ 'evaluate.metrics.text_consistency': False,
43
+ }
44
+ # binary for FAD computation: replace this path with your own path
45
+ metrics_opts = {
46
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
47
+ }
48
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
49
+ opt2 = {'transformer_lm.two_step_cfg': True}
50
+
51
+ sub = launcher.bind(opts)
52
+ sub.bind_(metrics_opts)
53
+
54
+ # base objective metrics
55
+ sub(opt1, opt2)
56
+
57
+ if eval_melody:
58
+ # chroma-specific metrics
59
+ sub(opt1, opt2, chroma_opts)
60
+
61
+
62
+ @GenerationEvalExplorer
63
+ def explorer(launcher):
64
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
65
+ launcher.slurm_(gpus=4, partition=partitions)
66
+
67
+ if 'REGEN' not in os.environ:
68
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
69
+ with launcher.job_array():
70
+ for sig in folder.iterdir():
71
+ if not sig.is_symlink():
72
+ continue
73
+ xp = train.main.get_xp_from_sig(sig.name)
74
+ launcher(xp.argv)
75
+ return
76
+
77
+ with launcher.job_array():
78
+ musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
79
+ musicgen_base.bind_({'autocast': False, 'fsdp.use': True})
80
+
81
+ # base musicgen models
82
+ musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
83
+ eval(musicgen_base_small, batch_size=128)
84
+
85
+ musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
86
+ musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
87
+ eval(musicgen_base_medium, batch_size=128)
88
+
89
+ musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
90
+ musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
91
+ eval(musicgen_base_large, batch_size=128)
92
+
93
+ # melody musicgen model
94
+ musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
95
+ musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})
96
+
97
+ musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
98
+ musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
99
+ eval(musicgen_melody_medium, batch_size=128, eval_melody=True)