test / tools /augment_imagenetc.py
Tu Bui
first commit
6142a25
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
wrapper for imagenet-c transformations
@author: Tu Bui @surrey.ac.uk
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import random
import numpy as np
from PIL import Image
from imagenet_c import corrupt, corruption_dict
class IdentityAugment(object):
def __call__(self, x):
return x
def __repr__(self):
s = f'()'
return self.__class__.__name__ + s
class RandomImagenetC(object):
# transform id 5 (motion blur) and 7 (snow) requires WandImage which is not fork-safe, while id 4 (glass blur) and 6 (zoom blur) are super slow thus we move it to validation (unseen), 12 (elastic transform) is non realistic
methods = {'train': np.array([0,1,2,3,8,9,10,11,13,14,15, 16, 17, 18]),#np.arange(15),
'val': np.array([4, 5, 6, 7, 12]),
'test': np.array([0,1,2,3,8,9,10,11,13,14,15, 16, 17, 18])
}
method_names = list(corruption_dict.keys())
def __init__(self, min_severity=1, max_severity=5, phase='all', p=1.0,n=19):
assert phase in ['train', 'val', 'test', 'all'], ValueError(f'{phase} not recognised. Must be one of [train, val, all]')
if phase == 'all':
self.corrupt_ids = np.concatenate(list(self.methods.values()))
else:
self.corrupt_ids = self.methods[phase]
self.corrupt_ids = self.corrupt_ids[:n] # first n tforms
self.phase = phase
self.severity = np.arange(min_severity, max_severity+1)
self.p = p # probability to apply a transformation
def __call__(self, x, corrupt_id=None, corrupt_strength=None):
# input: x PIL image
if corrupt_id is None:
if len(self.corrupt_ids)==0: # do nothing
return x
corrupt_id = np.random.choice(self.corrupt_ids)
else:
assert corrupt_id in range(19)
severity = np.random.choice(self.severity) if corrupt_strength is None else corrupt_strength
assert severity in self.severity, f'Error! Corrupt strength {severity} isnt supported.'
if np.random.rand() < self.p:
org_size = x.size
x = np.asarray(x.convert('RGB').resize((224, 224), Image.BILINEAR))[:,:,::-1]
x = corrupt(x, severity, corruption_number=corrupt_id)
x = Image.fromarray(x[:,:,::-1])
if x.size != org_size:
x = x.resize(org_size, Image.BILINEAR)
return x
def transform_with_fixed_severity(self, x, severity, corrupt_id=None):
if corrupt_id is None:
corrupt_id = np.random.choice(self.corrupt_ids)
else:
assert corrupt_id in self.corrupt_ids
assert severity > 0 and severity < 6
org_size = x.size
x = np.asarray(x.convert('RGB').resize((224, 224), Image.BILINEAR))[:,:,::-1]
x = corrupt(x, severity, corruption_number=corrupt_id)
x = Image.fromarray(x[:,:,::-1])
if x.size != org_size:
x = x.resize(org_size, Image.BILINEAR)
return x
def __repr__(self):
s = f'(severity={self.severity}, phase={self.phase}, p={self.p},ids={self.corrupt_ids})'
return self.__class__.__name__ + s
class NoiseResidual(object):
def __init__(self, k=16):
self.k = k
def __call__(self, x):
h, w = x.height, x.width
x1 = x.resize((w//self.k,h//self.k), Image.BILINEAR).resize((w, h), Image.BILINEAR)
x1 = np.abs(np.array(x).astype(np.float32) - np.array(x1).astype(np.float32))
x1 = (x1 - x1.min())/(x1.max() - x1.min() + np.finfo(np.float32).eps)
x1 = Image.fromarray((x1*255).astype(np.uint8))
return x1
def __repr__(self):
s = f'(k={self.k}'
return self.__class__.__name__ + s
def get_transforms(img_mean=[0.5, 0.5, 0.5], img_std=[0.5, 0.5, 0.5], rsize=256, csize=224, pertubation=True, dct=False, residual=False, max_c=19):
from torchvision import transforms
prep = transforms.Compose([
transforms.Resize(rsize),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(csize)])
if pertubation:
pertubation_train = RandomImagenetC(max_severity=5, phase='train', p=0.95,n=max_c)
pertubation_val = RandomImagenetC(max_severity=5, phase='train', p=1.0,n=max_c)
pertubation_test = RandomImagenetC(max_severity=5, phase='val', p=1.0,n=max_c)
else:
pertubation_train = pertubation_val = pertubation_test = IdentityAugment()
if dct:
from .image_tools import DCT
norm = [
DCT(),
transforms.ToTensor(),
transforms.Normalize(mean=img_mean, std=img_std)]
else:
norm = [
transforms.ToTensor(),
transforms.Normalize(mean=img_mean, std=img_std)]
if residual:
norm.insert(0, NoiseResidual())
preprocess = {
'train': [prep, pertubation_train, transforms.Compose(norm)],
'val': [prep, pertubation_val, transforms.Compose(norm)],
'test_unseen': [prep, pertubation_test, transforms.Compose(norm)],
'clean': transforms.Compose([transforms.Resize(csize)] + norm)
}
return preprocess
# ## example
# from PIL import Image
# import numpy as np
# import time
# from imagenet_c import corrupt, corruption_dict
# im = Image.open('/vol/research/tubui1/projects/gan_prov/gan_models/stargan2/test.jpg').convert('RGB').resize((224,224), Image.BILINEAR)
# im.save('original.jpg')
# im = np.array(im)[:,:,::-1] # BRG
# t = np.zeros(19)
# for i, key in enumerate(corruption_dict.keys()):
# begin = time.time()
# for j in range(10):
# out = corrupt(im, 5, corruption_number=i)
# end = time.time()
# t[i] = end-begin
# # Image.fromarray(out[:,:,::-1]).save(f'imc_{key}.jpg')
# print(f'{i} - {key}: {end-begin}')
# for i,k in enumerate(corruption_dict.keys()):
# print(i, k, t[i])