File size: 2,677 Bytes
d72b2c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""

from pathlib import Path
import typing as tp

from omegaconf import OmegaConf
import torch

from audiocraft import __version__


def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given EnCodec checkpoint. This
    should be used if you trained your own EnCodec model.

    Args:
        checkpoint_path (Path or str): Path to the training checkpoint. It must contain a
            ``'best_state'`` entry with a ``'model'`` sub-entry, and an ``'xp.cfg'`` entry
            accepted by ``OmegaConf.to_yaml``.
        out_file (Path or str): Destination of the exported package. Missing parent
            directories are created.
    Returns:
        The ``out_file`` argument, unchanged.
    """
    # A training checkpoint may hold non-tensor Python objects (e.g. the experiment
    # config passed to `OmegaConf.to_yaml` below). Recent PyTorch releases default to
    # `weights_only=True` and refuse to unpickle those, so full unpickling is requested
    # explicitly. NOTE: this can execute arbitrary pickled code, so only export
    # checkpoints from a trusted source.
    pkg = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    new_pkg = {
        'best_state': pkg['best_state']['model'],
        # Serialized to a YAML string, presumably so the export holds only plain types.
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
        'version': __version__,
        'exported': True,
    }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file


def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.
    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
    Do not include the //pretrained/ prefix. For instance if you trained a model
    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.

    In that case, this will not actually include a copy of the model, simply the reference
    to the model used. If `pretrained_encodec` instead points to an existing file, that file
    must be an already-exported package (with the same keys `export_encodec` writes), and it
    is copied as-is to `out_file`.

    Args:
        pretrained_encodec (str): Name of a pretrained compression model, or path to an
            exported compression model package.
        out_file (Path or str): Destination of the package. Missing parent directories
            are created.
    Returns:
        The ``out_file`` argument, unchanged.
    Raises:
        ValueError: If ``pretrained_encodec`` is an existing file that lacks any of the
            ``'best_state'``, ``'xp.cfg'``, ``'version'`` or ``'exported'`` keys.
    """
    if Path(pretrained_encodec).exists():
        # Load on CPU so a package saved from GPU can be repackaged on a CPU-only host.
        pkg = torch.load(pretrained_encodec, map_location='cpu')
        # Explicit check rather than `assert`, which is stripped when running with -O.
        required_keys = ('best_state', 'xp.cfg', 'version', 'exported')
        missing = [key for key in required_keys if key not in pkg]
        if missing:
            raise ValueError(
                f"{pretrained_encodec} is not an exported compression model package: "
                f"missing key(s) {missing}.")
    else:
        # Only a reference to the pretrained model is stored, not its weights.
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    The ``'fsdp_best_state'`` entry is preferred when present and non-empty; otherwise
    ``'best_state'`` is used.

    Args:
        checkpoint_path (Path or str): Path to the training checkpoint. It must contain an
            ``'xp.cfg'`` entry accepted by ``OmegaConf.to_yaml``, and a best state
            dictionary with a ``'model'`` sub-entry.
        out_file (Path or str): Destination of the exported package. Missing parent
            directories are created.
    Returns:
        The ``out_file`` argument, unchanged.
    Raises:
        ValueError: If neither ``'fsdp_best_state'`` nor ``'best_state'`` is present
            and non-empty in the checkpoint.
    """
    # A training checkpoint may hold non-tensor Python objects (e.g. the experiment
    # config passed to `OmegaConf.to_yaml` below). Recent PyTorch releases default to
    # `weights_only=True` and refuse to unpickle those, so full unpickling is requested
    # explicitly. NOTE: this can execute arbitrary pickled code, so only export
    # checkpoints from a trusted source.
    pkg = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    # `.get` tolerates checkpoints written without the FSDP entry.
    fsdp_best_state = pkg.get('fsdp_best_state')
    if fsdp_best_state:
        best_state = fsdp_best_state['model']
    elif pkg.get('best_state'):
        best_state = pkg['best_state']['model']
    else:
        # Explicit error rather than `assert`, which is stripped when running with -O.
        raise ValueError(
            f"No best state found in checkpoint {checkpoint_path}: both 'fsdp_best_state' "
            "and 'best_state' are missing or empty.")
    new_pkg = {
        'best_state': best_state,
        # Serialized to a YAML string, presumably so the export holds only plain types.
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
        'version': __version__,
        'exported': True,
    }

    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file