# Exported from a Hugging Face file viewer; original upload metadata:
# Baraaqasem — "Upload 49 files" — revision 413d4d0 (verified) — raw / history / blame — 2.71 kB
import importlib
import torch
import json
import math
import os
import numpy as np
import torch.nn.functional as F
def new_module(config):
    """Instantiate an object from a config mapping or a JSON config file.

    Parameters
    ----------
    config : dict or str
        Either a mapping with keys:
          "target": dotted path of the class to instantiate (relative
                    paths are resolved against this package),
          "params": optional dict of keyword arguments for the constructor;
        or a path to a JSON file containing such a mapping.

    Returns
    -------
    The instantiated object.

    Raises
    ------
    KeyError
        If the config has no "target" entry.
    """
    if isinstance(config, str):
        # A string is treated as a path to a JSON config file.
        with open(config, "r", encoding="utf-8") as file:
            config = json.load(file)
    assert isinstance(config, dict)
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name, package=__package__), cls_name)
    return cls(**config.get("params", dict()))
def load_ckpt(model, path):
    """Load the weights stored under the checkpoint's 'module' key into *model*.

    The checkpoint is mapped to CPU, and loading is non-strict: keys that
    do not match the model are silently skipped. Returns the same model.
    """
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["module"], strict=False)
    return model
def load_default_HVQVAE():
    """Build the default hierarchical VQ-VAE via `new_module`.

    The configuration is fixed: 3 levels, 256-dim embeddings, a
    20000-entry EMA codebook, two residual down-samplers, one encoder,
    and one decoder per level.
    """
    down_samplers = [
        {"target": "..vqvae.ResidualDownSample", "params": {"in_channels": 256}}
        for _ in range(2)
    ]
    decoders = [
        {"target": "..vqvae.Decoder", "params": {"channels_mult": mult}}
        for mult in ([1, 1, 1, 2, 4], [1, 1, 2, 4], [1, 2, 4])
    ]
    config = {
        "target": "..vqvae.HVQVAE",
        "params": {
            "levels": 3,
            "embedding_dim": 256,
            "codebook_scale": 1,
            "down_sampler_configs": down_samplers,
            "enc_config": {
                "target": "..vqvae.Encoder",
                "params": {
                    "num_res_blocks": 2,
                    "channels_mult": [1, 2, 4],
                },
            },
            "quantize_config": {
                "target": "..vqvae.VectorQuantizeEMA",
                "params": {
                    "hidden_dim": 256,
                    "embedding_dim": 256,
                    "n_embed": 20000,
                    "training_loc": False,
                },
            },
            "dec_configs": decoders,
        },
    }
    return new_module(config)
if __name__ == '__main__':
    # No CLI behavior: this module is intended to be imported for its helpers.
    pass