import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class SimpleSAFM(nn.Module):
    """Simplified spatially-adaptive feature modulation (SAFM).

    Half of the channels are modulated by a map estimated at 1/8 resolution;
    the other half pass through unchanged.
    """
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, 3, 1, 1, bias=False)
        self.dwconv = nn.Conv2d(dim // 2, dim // 2, 3, 1, 1, groups=dim // 2, bias=False)
        self.out = nn.Conv2d(dim, dim, 1, 1, 0, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        x0, x1 = self.proj(x).chunk(2, dim=1)

        # Estimate the modulation map on a downsampled copy of x0,
        # then upsample it back to the input resolution.
        x2 = F.adaptive_max_pool2d(x0, (h // 8, w // 8))
        x2 = self.dwconv(x2)
        x2 = F.interpolate(x2, size=(h, w), mode='bilinear')
        x2 = self.act(x2) * x0

        x = torch.cat([x1, x2], dim=1)
        x = self.out(self.act(x))
        return x


class CCM(nn.Module):
    """Convolutional channel mixer: 3x3 expansion -> GELU -> 1x1 projection."""
    def __init__(self, dim, ffn_scale):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dim, int(dim * ffn_scale), 3, 1, 1, bias=False),
            nn.GELU(),
            nn.Conv2d(int(dim * ffn_scale), dim, 1, 1, 0, bias=False)
        )

    def forward(self, x):
        return self.conv(x)


class AttBlock(nn.Module):
    """Feature mixing block: SimpleSAFM followed by the channel mixer."""
    def __init__(self, dim, ffn_scale):
        super().__init__()
        self.conv1 = SimpleSAFM(dim)
        self.conv2 = CCM(dim, ffn_scale)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out


class SAFMNPP(nn.Module):
    """SAFMN++: shallow feature extraction, a stack of AttBlocks with a global
    residual, and pixel-shuffle upsampling.

    Operates on video clips of shape (b, t, c, h, w); frames are folded into
    the batch axis and super-resolved independently.
    """
    def __init__(self, dim=32, n_blocks=2, ffn_scale=1.5, upscaling_factor=4):
        super().__init__()
        self.scale = upscaling_factor

        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1, bias=False)

        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])

        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1, bias=False),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x):
        b = x.shape[0]
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.to_feat(x)
        x = self.feats(x) + x
        x = self.to_img(x)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b)
        return x


if __name__ == '__main__':
    ############# Test Model Complexity #############
    # import time
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
    from tqdm import tqdm

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Target (HR) resolution; the LR input below is h//scale x w//scale.
    scale = 4
    h, w = 3840, 2160
    # scale = 3
    # h, w = 1920, 1080
    x = torch.randn(1, 30, 3, h // scale, w // scale)

    model = SAFMNPP(upscaling_factor=scale)
    model.load_state_dict(torch.load('light_safmnpp.pth')['params'], strict=True)
    # output = model(x)
    print(model)
    # print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
    # print(output.shape)

    # num_frame = 30
    # clip = 5
    # torch.cuda.current_device()
    # torch.cuda.empty_cache()
    # torch.backends.cudnn.benchmark = False
    # start = torch.cuda.Event(enable_timing=True)
    # end = torch.cuda.Event(enable_timing=True)
    # runtime = 0
    # dummy_input = torch.randn((1, num_frame, 3, h // scale, w // scale)).to(device)
    # # warm up
    # model.eval().to(device)
    # with torch.no_grad():
    #     for _ in tqdm(range(clip)):
    #         _ = model(dummy_input)
    #     for _ in tqdm(range(clip)):
    #         start.record()
    #         _ = model(dummy_input)
    #         end.record()
    #         torch.cuda.synchronize()
    #         runtime += start.elapsed_time(end)
    # per_frame_time = runtime / (num_frame * clip)
    # print(f'{model.__class__.__name__} {num_frame * clip} Frames x{scale}SR Per-Frame Time: {per_frame_time:.6f} ms')
    # print(f'{model.__class__.__name__} x{scale}SR FPS: {(1000 / per_frame_time):.6f} FPS')
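
    # --- Optional smoke test (a minimal sketch, not part of the original
    # script). It uses randomly initialized weights, so the
    # 'light_safmnpp.pth' checkpoint is not required; it only checks that a
    # x4 model produces the expected output shape for a 5-frame clip.
    # model = SAFMNPP(dim=32, n_blocks=2, ffn_scale=1.5, upscaling_factor=4)
    # model.eval()
    # with torch.no_grad():
    #     lr = torch.randn(1, 5, 3, 64, 64)    # (batch, frames, c, h, w)
    #     sr = model(lr)
    # assert sr.shape == (1, 5, 3, 256, 256)   # each frame upscaled x4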