# NOTE: page-scrape residue, kept as comments so the module can be parsed:
# Spaces:
# Runtime error
# Runtime error
import importlib
import json
import math
import os

import numpy as np
import torch
import torch.nn.functional as F
def new_module(config):
    """Instantiate the object described by a ``target``/``params`` config.

    Args:
        config: A dict, or a path (``str`` or ``os.PathLike``) to a JSON
            file containing one. Recognised keys:

            * ``"target"``: dotted path ``"<module>.<attribute>"`` of the
              callable to instantiate. A module part starting with a dot is
              imported relative to this module's package.
            * ``"params"``: optional dict of keyword arguments for the
              target. A missing or ``None`` value means "no arguments".

    Returns:
        Whatever calling the target with ``**params`` returns.

    Raises:
        TypeError: If the (loaded) config is not a dict.
        KeyError: If the config has no ``"target"`` key.
        ValueError: If ``target`` has no module part before the last dot.
    """
    if isinstance(config, (str, os.PathLike)):
        with open(config, "r", encoding="utf-8") as file:
            config = json.load(file)

    # Explicit raise instead of `assert`: asserts vanish under `python -O`.
    if not isinstance(config, dict):
        raise TypeError(
            f"Expected config to be a dict, got {type(config).__name__}."
        )
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    target = config["target"]
    # Split on the LAST dot only, so relative prefixes such as "..pkg.Cls"
    # keep their leading dots in the module part.
    module_name, _, attr_name = target.rpartition(".")
    if not module_name:
        raise ValueError(
            f"`target` must look like '<module>.<attribute>', got {target!r}."
        )

    module = importlib.import_module(module_name, package=__package__)
    params = config.get("params")
    if params is None:
        params = {}
    return getattr(module, attr_name)(**params)
def load_ckpt(model, path):
    """Load weights from a checkpoint file into ``model`` and return it.

    The checkpoint is read with ``torch.load`` onto the CPU, and the state
    dict is taken from its ``'module'`` entry.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        path: Checkpoint location, passed straight to ``torch.load``.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        KeyError: If the checkpoint has no ``'module'`` entry.
    """
    # NOTE(review): torch.load relies on unpickling; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(path, map_location="cpu")
    try:
        state_dict = checkpoint['module']
    except KeyError:
        raise KeyError(
            f"Checkpoint {path!r} has no 'module' entry; "
            f"available keys: {sorted(map(str, checkpoint))}"
        ) from None
    # strict=False: key mismatches between checkpoint and model are not
    # treated as errors.
    model.load_state_dict(state_dict, strict=False)
    return model
def load_default_HVQVAE():
    """Build the default HVQVAE by handing a fixed config to ``new_module``.

    The config describes 3 levels with 256-dim embeddings: two
    ``ResidualDownSample`` entries, one ``Encoder``, one
    ``VectorQuantizeEMA`` with 20000 codes, and three ``Decoder`` entries
    with progressively shorter ``channels_mult`` lists. All targets are
    relative (``"..vqvae.<Name>"``), so they are resolved against this
    module's package by ``new_module``.

    Returns:
        Whatever ``new_module`` returns for the ``"..vqvae.HVQVAE"`` target.
    """
    # Each entry is built separately so the list never shares dict objects.
    down_sampler_configs = [
        {
            "target": "..vqvae.ResidualDownSample",
            "params": {"in_channels": 256},
        }
        for _ in range(2)
    ]
    dec_configs = [
        {
            "target": "..vqvae.Decoder",
            "params": {"channels_mult": channels_mult},
        }
        for channels_mult in ([1, 1, 1, 2, 4], [1, 1, 2, 4], [1, 2, 4])
    ]
    config = {
        "target": "..vqvae.HVQVAE",
        "params": {
            "levels": 3,
            "embedding_dim": 256,
            "codebook_scale": 1,
            "down_sampler_configs": down_sampler_configs,
            "enc_config": {
                "target": "..vqvae.Encoder",
                "params": {
                    "num_res_blocks": 2,
                    "channels_mult": [1, 2, 4],
                },
            },
            "quantize_config": {
                "target": "..vqvae.VectorQuantizeEMA",
                "params": {
                    "hidden_dim": 256,
                    "embedding_dim": 256,
                    "n_embed": 20000,
                    "training_loc": False,
                },
            },
            "dec_configs": dec_configs,
        },
    }
    # Nested configs are passed through as plain dicts; presumably HVQVAE
    # instantiates them itself (verify against vqvae.HVQVAE).
    return new_module(config)
# Running this file directly is intentionally a no-op; the module only
# provides the loader helpers defined above.
if __name__ == '__main__':
    pass