# APISR / test_code/test_utils.py
# Author: HikariDawn
# Last commit 9bf54b1: "feat: DAT and comparison"
import os, sys
import torch
# Import files from same folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from architecture.rrdb import RRDBNet
from architecture.grl import GRL
from architecture.dat import DAT
from architecture.swinir import SwinIR
from architecture.cunet import UNet_Full
def load_rrdb(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load RRDB model from Real-ESRGAN
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor
        print_options (bool): whether to print options to show what kinds of setting is used
    Returns:
        generator (torch): the generator instance of the model, in eval mode on CUDA
    Raises:
        ValueError: if the checkpoint contains none of 'params_ema', 'params'
            or 'model_state_dict'
    '''
    # Load onto CPU first so a checkpoint saved on a different GPU index still
    # loads; the model is moved to CUDA explicitly after the weights are in.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight, in priority order:
    #   'params_ema' / 'params' -> official ESRNET/ESRGAN weight
    #   'model_state_dict'      -> personally trained weight
    for weight_key in ("params_ema", "params", "model_state_dict"):
        if weight_key in checkpoint_g:
            weight = checkpoint_g[weight_key]
            break
    else:
        # Raise instead of os._exit(0): exiting with status 0 reported success,
        # and os._exit skips stdio flushing so the message could be lost.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")

    # Every supported format uses the same architecture. The block count comes
    # from RRDBNet's own default (an earlier note said 6; TODO confirm in
    # architecture/rrdb.py).
    generator = RRDBNet(3, 3, scale=scale)

    # Handle torch.compile weight key rename ("_orig_mod." prefix). Mutate the
    # state dict in place, iterating over a snapshot of its keys, so any
    # attribute attached to the dict object (e.g. _metadata) is preserved.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        opts = checkpoint_g['opt']
        for key in opts:
            print(f'{key} : {opts[key]}')

    return generator
def load_cunet(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load CUNET model from Real-CUGAN
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor (only 2 is supported)
        print_options (bool): whether to print options to show what kinds of setting is used
    Returns:
        generator (torch): the generator instance of the model, in eval mode on CUDA
    Raises:
        NotImplementedError: if scale is not 2
        ValueError: if the checkpoint has no 'model_state_dict'
    '''
    # NOTE: this loader is deprecated.
    if scale != 2:
        raise NotImplementedError("We only support 2x in CUNET")

    # Load onto CPU first; the model is moved to CUDA explicitly below.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Only personally trained weights (saved with 'model_state_dict') are supported.
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0), which reported success and could drop
        # the unflushed message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # Training metadata is informational only. Both fields fall back to "NAN"
    # instead of raising KeyError when the checkpoint lacks them.
    loss = checkpoint_g.get("lowest_generator_weight", "NAN")
    iteration = checkpoint_g.get("iteration", "NAN")
    generator = UNet_Full()
    print(f"the generator weight is {loss} at iteration {iteration}")

    # Handle torch.compile weight key rename ("_orig_mod." prefix), in place so
    # any attribute attached to the state dict object is preserved.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        opts = checkpoint_g['opt']
        for key in opts:
            print(f'{key} : {opts[key]}')

    return generator
def load_grl(generator_weight_PATH, scale=4):
    ''' A simpler API to load GRL model
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): Scale Factor (Usually Set as 4)
    Returns:
        generator (torch): the generator instance of the model, in eval mode on CUDA
    Raises:
        ValueError: if the checkpoint has no 'model_state_dict'
    '''
    # Load onto CPU first; the model is moved to CUDA once, after loading.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0), which reported success and could drop
        # the unflushed message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # GRL tiny model (Note: tiny2 version)
    generator = GRL(
        upscale = scale,
        img_size = 64,
        window_size = 8,
        depths = [4, 4, 4, 4],
        embed_dim = 64,
        num_heads_window = [2, 2, 2, 2],
        num_heads_stripe = [2, 2, 2, 2],
        mlp_ratio = 2,
        qkv_proj_type = "linear",
        anchor_proj_type = "avgpool",
        anchor_window_down_factor = 2,
        out_proj_type = "linear",
        conv_type = "1conv",
        upsampler = "nearest+conv",  # Change
    )

    # Handle torch.compile weight key rename ("_orig_mod." prefix), as the
    # other loaders in this file do. Done in place to keep attributes attached
    # to the state dict object.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report trainable parameter count in millions
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator
def load_dat(generator_weight_PATH, scale=4):
    ''' A simpler API to load DAT model (small configuration)
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): Scale Factor (Usually Set as 4); passed through to DAT's `upscale`
    Returns:
        generator (torch): the generator instance of the model, in eval mode on CUDA
    Raises:
        ValueError: if the checkpoint has no 'model_state_dict'
    '''
    # Load onto CPU first; the model is moved to CUDA once, after loading.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0), which reported success and could drop
        # the unflushed message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # DAT small model in default. `upscale` previously hard-coded 4, silently
    # ignoring the `scale` argument.
    generator = DAT(upscale = scale,
                    in_chans = 3,
                    img_size = 64,
                    img_range = 1.,
                    depth = [6, 6, 6, 6, 6, 6],
                    embed_dim = 180,
                    num_heads = [6, 6, 6, 6, 6, 6],
                    expansion_factor = 2,
                    resi_connection = '1conv',
                    split_size = [8, 16],
                    upsampler = 'pixelshuffledirect',
                    )

    # Handle torch.compile weight key rename ("_orig_mod." prefix), as the
    # other loaders in this file do. Done in place to keep attributes attached
    # to the state dict object.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report trainable parameter count in millions
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator