|
import torch |
|
import numpy as np |
|
import skimage.metrics |
|
import lpips |
|
from PIL import Image |
|
from .sifid import SIFID |
|
|
|
|
|
def resize_array(x, size=256):
    """
    Resize a batch of images to a square of the given size.

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
            Each image is handed to ``PIL.Image.fromarray``, so the dtype is
            presumably uint8 -- confirm against callers.
        size (int): Side length of the square output images.

    Returns:
        (np.ndarray): Image array of shape (N, size, size, C) in range
            [0, 255]. The input is returned untouched when it already has
            the requested height and width.
    """
    # Compare BOTH spatial dimensions: checking the height alone would let a
    # non-square input whose height already equals `size` through unresized.
    if tuple(x.shape[1:3]) != (size, size):
        resized = [
            Image.fromarray(frame).resize((size, size), resample=Image.BILINEAR)
            for frame in x
        ]
        x = np.array([np.array(image) for image in resized])
    return x
|
|
|
|
|
def resize_tensor(x, size=256):
    """
    Resize a batch of image tensors to a square of the given size.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        size (int): Side length of the square output images.

    Returns:
        (torch.Tensor): Image tensor of shape (N, C, size, size). The input
            tensor itself is returned when it already has the requested
            height and width.
    """
    # Compare BOTH spatial dimensions: checking the height alone would let a
    # non-square input whose height already equals `size` through unresized.
    if tuple(x.shape[2:4]) != (size, size):
        x = torch.nn.functional.interpolate(
            x, size=(size, size), mode='bilinear', align_corners=False
        )
    return x
|
|
|
|
|
def normalise(x):
    """
    Convert a channels-last image array into a channels-first float tensor.

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].

    Returns:
        (torch.Tensor): float32 image tensor of shape (N, C, H, W) in range
            [-1, 1].
    """
    # [0, 255] -> [0, 1]
    unit = x.astype(np.float32) / 255
    # [0, 1] -> [-1, 1]
    centred = (unit - 0.5) / 0.5
    # NHWC -> NCHW
    return torch.from_numpy(centred).permute(0, 3, 1, 2)
|
|
|
|
|
def unormalise(x, vrange=(-1, 1)):
    """
    Map an image tensor back to a uint8, channels-last image array.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) whose values lie
            in ``vrange``.
        vrange (sequence of two numbers): ``(low, high)`` value range of
            ``x``; ``low`` maps to 0 and ``high`` maps to 255.

    Returns:
        (np.ndarray): uint8 image array of shape (N, H, W, C) in range
            [0, 255].
    """
    low, high = vrange
    # detach() so tensors that still track gradients can be converted to numpy.
    scaled = (x.detach() - low) / (high - low) * 255
    # Clamp and round before the integer cast: a bare cast truncates (so a
    # value like 254.99998 would become 254) and out-of-range values would
    # wrap around instead of saturating.
    scaled = scaled.clamp(0, 255).round()
    # NCHW -> NHWC
    scaled = scaled.permute(0, 2, 3, 1)
    return scaled.cpu().numpy().astype(np.uint8)
|
|
|
|
|
def compute_mse(x, y):
    """
    Compute the per-image mean squared error between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].

    Returns:
        (1darray): Mean squared error of each of the N image pairs.
    """
    # Promote to float before subtracting: with uint8 inputs the difference
    # wraps around modulo 256 and the square overflows, giving a wrong error.
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    squared_error = np.square(x - y)
    # Flatten everything except the batch axis and average per image.
    return squared_error.reshape(x.shape[0], -1).mean(axis=1)
|
|
|
|
|
def compute_psnr(x, y):
    """
    Compute the per-image peak signal-to-noise ratio between two batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].

    Returns:
        (1darray): Peak signal-to-noise ratio in dB of each of the N image
            pairs, using 255 as the peak value. Identical images give
            ``inf``.
    """
    # Hand floating-point arrays to the MSE so that uint8 inputs cannot wrap
    # around during the subtraction.
    mse = compute_mse(
        np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
    )
    # A zero MSE (identical images) legitimately yields +inf; silence the
    # divide-by-zero warning numpy would otherwise raise for it.
    with np.errstate(divide='ignore'):
        return 10 * np.log10(255 ** 2 / mse)
|
|
|
|
|
def compute_ssim(x, y):
    """
    Compute the per-image structural similarity index between two batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].

    Returns:
        (np.ndarray): Structural similarity index of each image pair, in
            batch order.
    """
    # Settings shared by every pair: channels on the last axis of each image,
    # Gaussian weighting with sigma 1.5, population covariance, 8-bit range.
    ssim_kwargs = dict(
        channel_axis=2,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        data_range=255,
    )
    scores = []
    for image_x, image_y in zip(x, y):
        score = skimage.metrics.structural_similarity(image_x, image_y, **ssim_kwargs)
        scores.append(score)
    return np.array(scores)
|
|
|
|
|
def compute_lpips(x, y, net='alex'):
    """
    Compute LPIPS between two batches of images.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (str or callable): Either the name of an LPIPS backbone, in which
            case a fresh ``lpips.LPIPS`` model is built on every call, or an
            already constructed model/callable taking ``(x, y)``. Pass a
            pre-built model to avoid rebuilding it for each call.

    Returns:
        (np.ndarray): LPIPS distances with singleton dimensions squeezed out
            (0-d for a single image pair).
    """
    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(net, str):
        device = default_device
        lpips_fn = lpips.LPIPS(net=net, verbose=False).to(device)
    else:
        lpips_fn = net
        # Run on whatever device a caller-supplied model already lives on.
        params = net.parameters() if isinstance(net, torch.nn.Module) else iter(())
        first_param = next(params, None)
        device = default_device if first_param is None else first_param.device

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        distance = lpips_fn(x.to(device), y.to(device))
    return distance.detach().cpu().numpy().squeeze()
|
|
|
|
|
def compute_sifid(x, y, net=None):
    """
    Compute SIFID between corresponding images of two batches.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (callable, optional): Metric called as ``net(xi, yi)`` on each
            pair of batch items. A new ``SIFID()`` is created when omitted.

    Returns:
        (np.ndarray): One SIFID value per image pair, in batch order.
    """
    metric = net if net is not None else SIFID()
    # map() walks both batches in lockstep and stops at the shorter one.
    return np.array(list(map(metric, x, y)))