File size: 19,060 Bytes
81d747c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 |
""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
Hacked together by / Copyright 2019, Ross Wightman
"""
import math
from typing import Optional, Tuple, Union
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
from timm.data.random_erasing import RandomErasing
def transforms_noaug_train(
        img_size: Union[int, Tuple[int, int]] = 224,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        use_prefetcher: bool = False,
        normalize: bool = True,
):
    """ Training transforms with augmentation disabled (resize + center crop only).

    Args:
        img_size: Target image size.
        interpolation: Image interpolation mode. 'random' is replaced by 'bilinear' here.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).

    Returns:
        A ``transforms.Compose`` of the resize, center crop and output conversion steps.
    """
    # random interpolation has no meaning without augmentation, fall back to bilinear
    interp_name = 'bilinear' if interpolation == 'random' else interpolation

    pipeline = [
        transforms.Resize(img_size, interpolation=str_to_interp_mode(interp_name)),
        transforms.CenterCrop(img_size),
    ]

    if use_prefetcher:
        # tensor conversion and normalization are left to the prefetcher / collate
        pipeline.append(ToNumpy())
    elif normalize:
        pipeline.extend([
            MaybeToTensor(),
            transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
        ])
    else:
        # normalization off: convert to tensor without scaling, original dtype is kept
        pipeline.append(MaybePILToTensor())

    return transforms.Compose(pipeline)
def transforms_imagenet_train(
        img_size: Union[int, Tuple[int, int]] = 224,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        train_crop_mode: Optional[str] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
        color_jitter_prob: Optional[float] = None,
        force_color_jitter: bool = False,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        interpolation: str = 'random',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_num_splits: int = 0,
        use_prefetcher: bool = False,
        normalize: bool = True,
        separate: bool = False,
):
    """ Build the ImageNet-style training augmentation pipeline.

    The pipeline is assembled in three stages: a primary stage (random crop / resize and
    flips), a secondary stage (auto-augment, color jitter, grayscale, blur) and a final
    stage (tensor conversion, normalization, random erasing).

    Args:
        img_size: Target image size.
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
            Defaults to (0.08, 1.0) for 'rrc' and (0.8, 1.0) for 'rkrc' / 'rkrr'.
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
            Defaults to (3/4, 4/3) for 'rrc' and (0.9, 1/0.9) for 'rkrc' / 'rkrr'.
        train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr'). None means 'rrc'.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        force_color_jitter: Force color jitter where it is normally disabled (ie with RandAugment on).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_num_splits: Control split of random erasing across batch size.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
        separate: Output transforms in 3-stage tuple.

    Returns:
        A single composed transform, unless separate==True, in which case the transforms are
        returned as a tuple of 3 separate transforms for use in a mixing dataset that passes
         * all data through the first (primary) transform, called the 'clean' data
         * a portion of the data through the secondary transform
         * normalizes and converts the branches above with the third, final transform
    """
    crop_mode = train_crop_mode or 'rrc'
    assert crop_mode in {'rrc', 'rkrc', 'rkrr'}

    # ---- stage 1: spatial crop / resize + flips ----
    if crop_mode not in ('rkrc', 'rkrr'):
        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size,
                scale=scale,
                ratio=ratio,
                interpolation=interpolation,
            ),
        ]
    else:
        # FIXME integration of RKR is a WIP
        scale = tuple(scale or (0.8, 1.00))
        ratio = tuple(ratio or (0.9, 1/.9))
        keep_ratio_resize = ResizeKeepRatio(
            img_size,
            interpolation=interpolation,
            random_scale_prob=0.5,
            random_scale_range=scale,
            random_scale_area=True,  # scale compatible with RRC
            random_aspect_prob=0.5,
            random_aspect_range=ratio,
        )
        # 'rkrc' crops (or pads) around the center, 'rkrr' at a random location
        if crop_mode == 'rkrc':
            crop_or_pad = CenterCropOrPad(img_size, padding_mode='reflect')
        else:
            crop_or_pad = RandomCropOrPad(img_size, padding_mode='reflect')
        primary_tfl = [keep_ratio_resize, crop_or_pad]

    if hflip > 0.:
        primary_tfl.append(transforms.RandomHorizontalFlip(p=hflip))
    if vflip > 0.:
        primary_tfl.append(transforms.RandomVerticalFlip(p=vflip))

    # ---- stage 2: photometric / policy based augmentation ----
    secondary_tfl = []
    skip_color_jitter = False
    if auto_augment:
        assert isinstance(auto_augment, str)
        # color jitter is typically disabled if AA/RA on,
        # the force flag (or a '3a' policy) allows override without breaking old hparam cfgs
        skip_color_jitter = not force_color_jitter and '3a' not in auto_augment

        img_size_min = min(img_size) if isinstance(img_size, (tuple, list)) else img_size
        aa_params = {
            'translate_const': int(img_size_min * 0.45),
            'img_mean': tuple(min(255, round(255 * x)) for x in mean),
        }
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = str_to_pil_interp(interpolation)

        # the prefix of the config string selects the auto-augment family
        if auto_augment.startswith('rand'):
            aa_transform = rand_augment_transform(auto_augment, aa_params)
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            aa_transform = augment_and_mix_transform(auto_augment, aa_params)
        else:
            aa_transform = auto_augment_transform(auto_augment, aa_params)
        secondary_tfl.append(aa_transform)

    if color_jitter is not None and not skip_color_jitter:
        # color jitter is enabled when not using AA or when forced
        if isinstance(color_jitter, (list, tuple)):
            # 3 values => brightness/contrast/saturation, 4 values => hue as well
            assert len(color_jitter) in (3, 4)
        else:
            # scalar is duplicated for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        jitter = transforms.ColorJitter(*color_jitter)
        if color_jitter_prob is not None:
            jitter = transforms.RandomApply([jitter], p=color_jitter_prob)
        secondary_tfl.append(jitter)

    if grayscale_prob:
        secondary_tfl.append(transforms.RandomGrayscale(p=grayscale_prob))

    if gaussian_blur_prob:
        blur = transforms.GaussianBlur(kernel_size=23)  # hardcoded for now
        secondary_tfl.append(transforms.RandomApply([blur], p=gaussian_blur_prob))

    # ---- stage 3: output conversion, normalization, random erasing ----
    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl.append(ToNumpy())
    elif normalize:
        final_tfl.extend([
            MaybeToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std),
            ),
        ])
    else:
        # normalization off: convert to tensor without scaling, original dtype is kept
        final_tfl.append(MaybePILToTensor())

    if re_prob > 0.:
        final_tfl.append(
            RandomErasing(
                re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device='cpu',
            )
        )

    if not separate:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
    return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
def transforms_imagenet_eval(
        img_size: Union[int, Tuple[int, int]] = 224,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        use_prefetcher: bool = False,
        normalize: bool = True,
):
    """ ImageNet-oriented image transform for evaluation and inference.

    Args:
        img_size: Target image size (scalar or 2-element (height, width) sequence).
        crop_pct: Crop percentage (output size / resize size). Falls back to
            DEFAULT_CROP_PCT when None (or otherwise falsy).
        crop_mode: Crop mode. One of ['squash', 'border', 'center']. Any value other than
            'squash' or 'border' (including None) selects 'center'.
        crop_border_pixels: Trim a border of specified # pixels around edge of original image.
        interpolation: Image interpolation mode, applied to the resize step of every crop mode.
        mean: Image normalization mean. Also used (scaled to 0-255) as pad fill in 'border' mode.
        std: Image normalization standard deviation.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).

    Returns:
        Composed transform pipeline
    """
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    # size the image is resized to before cropping: target size enlarged by 1 / crop_pct
    if isinstance(img_size, (tuple, list)):
        assert len(img_size) == 2
        scale_size = tuple([math.floor(x / crop_pct) for x in img_size])
    else:
        scale_size = math.floor(img_size / crop_pct)
        scale_size = (scale_size, scale_size)

    tfl = []

    if crop_border_pixels:
        tfl += [TrimBorder(crop_border_pixels)]

    if crop_mode == 'squash':
        # squash mode scales each edge to 1/pct of target, then crops
        # aspect ratio is not preserved, no img lost if crop_pct == 1.0
        tfl += [
            transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
            transforms.CenterCrop(img_size),
        ]
    elif crop_mode == 'border':
        # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
        # no image lost if crop_pct == 1.0
        fill = [round(255 * v) for v in mean]
        tfl += [
            ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
            CenterCropOrPad(img_size, fill=fill),
        ]
    else:
        # default crop mode is center
        # aspect ratio is preserved, crops center within image, no borders are added, image is lost
        if scale_size[0] == scale_size[1]:
            # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
            tfl += [
                transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
            ]
        else:
            # resize the shortest edge to matching target dim for non-square target.
            # The interpolation mode must be forwarded (as in the 'border' branch), otherwise
            # the requested mode is silently replaced by ResizeKeepRatio's own default.
            tfl += [ResizeKeepRatio(scale_size, interpolation=interpolation)]

        tfl += [transforms.CenterCrop(img_size)]

    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    elif not normalize:
        # when normalize disabled, converted to tensor without scaling, keeps original dtype
        tfl += [MaybePILToTensor()]
    else:
        tfl += [
            MaybeToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std),
            ),
        ]

    return transforms.Compose(tfl)
def create_transform(
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
        is_training: bool = False,
        no_aug: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_num_splits: int = 0,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        tf_preprocessing: bool = False,
        use_prefetcher: bool = False,
        normalize: bool = True,
        separate: bool = False,
):
    """ Select and build the transform pipeline for training, no-aug training or evaluation.

    Args:
        input_size: Target input size (channels, height, width) tuple or size scalar.
            For a tuple / list only the last two entries are used as the image size.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_num_splits: Control split of random erasing across batch size.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
            Only takes effect together with use_prefetcher.
        use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
        normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
        separate: Output transforms in 3-stage tuple (augmented training only).

    Returns:
        Composed transforms or tuple thereof
    """
    img_size = input_size[-2:] if isinstance(input_size, (tuple, list)) else input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        # imported lazily, only needed on this path
        from timm.data.tf_preprocessing import TfPreprocessTransform
        return TfPreprocessTransform(
            is_training=is_training,
            size=img_size,
            interpolation=interpolation,
        )

    # keyword arguments common to all three factory functions
    shared_kwargs = dict(
        interpolation=interpolation,
        mean=mean,
        std=std,
        use_prefetcher=use_prefetcher,
        normalize=normalize,
    )

    if not is_training:
        assert not separate, "Separate transforms not supported for validation preprocessing"
        return transforms_imagenet_eval(
            img_size,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            **shared_kwargs,
        )

    if no_aug:
        assert not separate, "Cannot perform split augmentation with no_aug"
        return transforms_noaug_train(img_size, **shared_kwargs)

    return transforms_imagenet_train(
        img_size,
        train_crop_mode=train_crop_mode,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        color_jitter_prob=color_jitter_prob,
        grayscale_prob=grayscale_prob,
        gaussian_blur_prob=gaussian_blur_prob,
        auto_augment=auto_augment,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=separate,
        **shared_kwargs,
    )
|