# Spaces: Running on T4
import os, sys | |
import torch | |
# Import files from same folder | |
root_path = os.path.abspath('.') | |
sys.path.append(root_path) | |
from opt import opt | |
from architecture.rrdb import RRDBNet | |
from architecture.grl import GRL | |
from architecture.dat import DAT | |
from architecture.swinir import SwinIR | |
from architecture.cunet import UNet_Full | |
def load_rrdb(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load RRDB model from Real-ESRGAN
    Args:
        generator_weight_PATH (str):    The path to the weight
        scale (int):                    the scaling factor
        print_options (bool):           whether to print the options stored under the checkpoint's 'opt' key
    Returns:
        generator (torch):              the generator instance of the model (eval mode, moved to CUDA)
    Raises:
        ValueError:                     if the checkpoint contains none of the recognised weight keys
    '''

    # Load the checkpoint on the CPU first, so a checkpoint saved from another
    # device still loads; the model is moved to CUDA once at the end.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight. Key priority matches the original if/elif chain:
    #   'params_ema' / 'params'  -> official ESRNET/ESRGAN weight
    #   'model_state_dict'       -> personal trained weight
    for weight_key in ('params_ema', 'params', 'model_state_dict'):
        if weight_key in checkpoint_g:
            weight = checkpoint_g[weight_key]
            break
    else:
        # Raise instead of os._exit(0): exiting with status 0 reported success
        # and killed the whole process without any cleanup.
        raise ValueError("This weight is not supported")

    # Every supported checkpoint flavour uses the same architecture
    # (original note: default blocks num is 6).
    generator = RRDBNet(3, 3, scale=scale)

    # Handle torch.compile weight key rename (strip the "_orig_mod." prefix).
    # Iterate over a snapshot of the keys because the dict is modified in place.
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        for key in checkpoint_g['opt']:
            value = checkpoint_g['opt'][key]
            print(f'{key} : {value}')

    return generator
def load_cunet(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load CUNET model from Real-CUGAN
    Args:
        generator_weight_PATH (str):    The path to the weight
        scale (int):                    the scaling factor (only 2 is accepted)
        print_options (bool):           whether to print the options stored under the checkpoint's 'opt' key
    Returns:
        generator (torch):              the generator instance of the model (eval mode, moved to CUDA)
    Raises:
        NotImplementedError:            if scale is not 2
        ValueError:                     if the checkpoint has no 'model_state_dict' entry
    '''
    # This func is deprecated now

    if scale != 2:
        raise NotImplementedError("We only support 2x in CUNET")

    # Load the checkpoint on the CPU first, so a checkpoint saved from another
    # device still loads; the model is moved to CUDA once at the end.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight (only the personal trained format is handled)
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): exiting with status 0 reported success
        # and killed the whole process without any cleanup.
        raise ValueError("This weight is not supported")
    weight = checkpoint_g['model_state_dict']

    # Both values are only used for the log line below, so a missing entry
    # falls back to "NAN" instead of aborting the load with a KeyError.
    loss = checkpoint_g.get("lowest_generator_weight", "NAN")
    iteration = checkpoint_g.get("iteration", "NAN")

    generator = UNet_Full()
    print(f"the generator weight is {loss} at iteration {iteration}")

    # Handle torch.compile weight key rename (strip the "_orig_mod." prefix).
    # Iterate over a snapshot of the keys because the dict is modified in place.
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        for key in checkpoint_g['opt']:
            value = checkpoint_g['opt'][key]
            print(f'{key} : {value}')

    return generator
def load_grl(generator_weight_PATH, scale=4):
    ''' A simpler API to load GRL model
    Args:
        generator_weight_PATH (str):    The path to the weight
        scale (int):                    Scale Factor (Usually Set as 4)
    Returns:
        generator (torch):              the generator instance of the model (eval mode, on CUDA)
    Raises:
        ValueError:                     if the checkpoint has no 'model_state_dict' entry
    '''

    # Load the checkpoint on the CPU first, so a checkpoint saved from another
    # device still loads; load_state_dict copies the values into the CUDA model.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): exiting with status 0 reported success
        # and killed the whole process without any cleanup.
        raise ValueError("This weight is not supported")
    weight = checkpoint_g['model_state_dict']

    # GRL tiny model (Note: tiny2 version)
    generator = GRL(
        upscale = scale,
        img_size = 64,
        window_size = 8,
        depths = [4, 4, 4, 4],
        embed_dim = 64,
        num_heads_window = [2, 2, 2, 2],
        num_heads_stripe = [2, 2, 2, 2],
        mlp_ratio = 2,
        qkv_proj_type = "linear",
        anchor_proj_type = "avgpool",
        anchor_window_down_factor = 2,
        out_proj_type = "linear",
        conv_type = "1conv",
        upsampler = "nearest+conv",     # Change
    ).cuda()

    # Handle torch.compile weight key rename (strip the "_orig_mod." prefix),
    # the same way the RRDB and CUNET loaders in this file do.
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the number of trainable parameters, in millions
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator
def load_dat(generator_weight_PATH, scale=4):
    ''' A simpler API to load DAT model
    Args:
        generator_weight_PATH (str):    The path to the weight
        scale (int):                    Scale Factor passed to DAT as `upscale` (defaults to 4)
    Returns:
        generator (torch):              the generator instance of the model (eval mode, on CUDA)
    Raises:
        ValueError:                     if the checkpoint has no 'model_state_dict' entry
    '''

    # Load the checkpoint on the CPU first, so a checkpoint saved from another
    # device still loads; load_state_dict copies the values into the CUDA model.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): exiting with status 0 reported success
        # and killed the whole process without any cleanup.
        raise ValueError("This weight is not supported")
    weight = checkpoint_g['model_state_dict']

    # DAT small model in default
    generator = DAT(
        upscale = scale,                # previously hard-coded to 4, which silently ignored `scale`
        in_chans = 3,
        img_size = 64,
        img_range = 1.,
        depth = [6, 6, 6, 6, 6, 6],
        embed_dim = 180,
        num_heads = [6, 6, 6, 6, 6, 6],
        expansion_factor = 2,
        resi_connection = '1conv',
        split_size = [8, 16],
        upsampler = 'pixelshuffledirect',
    ).cuda()

    # Handle torch.compile weight key rename (strip the "_orig_mod." prefix),
    # the same way the RRDB and CUNET loaders in this file do.
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the number of trainable parameters, in millions
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator