"""Release an OpenNMT-py model for inference.

Strips the training-only optimizer state from a checkpoint, or converts the
checkpoint to CTranslate2 format (optionally quantized).
"""
import argparse

import torch


def main():
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference"
    )
    parser.add_argument("--model", "-m", help="The model path", required=True)
    parser.add_argument("--output", "-o", help="The output path", required=True)
    parser.add_argument(
        "--format",
        choices=["pytorch", "ctranslate2"],
        default="pytorch",
        help="The format of the released model",
    )
    parser.add_argument(
        "--quantization",
        "-q",
        choices=["int8", "int16", "float16", "int8_float16"],
        default=None,
        help="Quantization type for CT2 model.",
    )
    opt = parser.parse_args()

    # Load the checkpoint on CPU so releasing a model does not require a GPU.
    model = torch.load(opt.model, map_location=torch.device("cpu"))
    if opt.format == "pytorch":
        # Drop the optimizer state: it is only needed to resume training,
        # not for inference, and it inflates the checkpoint size.
        model["optim"] = None
        torch.save(model, opt.output)
    elif opt.format == "ctranslate2":
        # Imported lazily so the PyTorch-only path works without ctranslate2.
        import ctranslate2

        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        converter.convert(opt.output, force=True, quantization=opt.quantization)


if __name__ == "__main__":
    main()
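
# Example invocations (a sketch; the script name, checkpoint path, and output
# paths below are hypothetical):
#
#   # Strip training state and write a smaller PyTorch checkpoint:
#   python release_model.py -m model_step_100000.pt -o model_released.pt
#
#   # Convert to an int8-quantized CTranslate2 model directory:
#   python release_model.py -m model_step_100000.pt -o ct2_model \
#       --format ctranslate2 --quantization int8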