'''
Dataloader to process the Adobe Image Matting dataset.
Adapted from GCA-Matting (https://github.com/Yaoyi-Li/GCA-Matting/tree/master/dataloader).
'''
import os
import glob
import logging
import os.path as osp
import functools
import numpy as np
import torch
import cv2
import math
import numbers
import random
import pickle
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import transforms
from easydict import EasyDict
from detectron2.utils.logger import setup_logger
from detectron2.utils import comm
from detectron2.data import build_detection_test_loader
import torchvision.transforms.functional
import json
from PIL import Image
from detectron2.evaluation.evaluator import DatasetEvaluator
from collections import defaultdict
from data.evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error
# Base default config
CONFIG = EasyDict({})
# Model config
CONFIG.model = EasyDict({})
# one-hot or class, choice: [3, 1]
CONFIG.model.trimap_channel = 1
# Dataloader config
CONFIG.data = EasyDict({})
# feed-forward image size (untested)
CONFIG.data.crop_size = 512
# probability of applying the CutMask perturbation
CONFIG.data.cutmask_prob = 0.25
# composition of two foregrounds, affine transform, crop and HSV jitter
CONFIG.data.augmentation = True
CONFIG.data.random_interp = True
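# Example (hypothetical override): these module-level flags are read by the
# transforms below at call time, so switching to a one-hot trimap encoding only
# requires flipping the config value before the dataloaders are built:
#   CONFIG.model.trimap_channel = 3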
class Prefetcher:
    """
    Asynchronously prefetch the next sample to the GPU on a side CUDA stream.
    Call iter() on this object before next(); __iter__ (re)creates the
    underlying iterator and warms up the pipeline.
    Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    """
def __init__(self, loader):
self.orig_loader = loader
self.stream = torch.cuda.Stream()
self.next_sample = None
def preload(self):
try:
self.next_sample = next(self.loader)
except StopIteration:
self.next_sample = None
return
with torch.cuda.stream(self.stream):
for key, value in self.next_sample.items():
if isinstance(value, torch.Tensor):
self.next_sample[key] = value.cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
sample = self.next_sample
if sample is not None:
for key, value in sample.items():
if isinstance(value, torch.Tensor):
sample[key].record_stream(torch.cuda.current_stream())
self.preload()
else:
            # behave like a default DataLoader: signal exhaustion to the caller
            raise StopIteration
return sample
def __iter__(self):
self.loader = iter(self.orig_loader)
self.preload()
return self
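# Usage sketch for Prefetcher (hypothetical names; assumes a CUDA device and an
# existing torch DataLoader named `train_loader`):
#
#   prefetcher = Prefetcher(train_loader)
#   for sample in prefetcher:        # __iter__ resets the underlying iterator
#       images = sample['image']     # tensors are already on the GPU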
class ImageFile(object):
def __init__(self, phase='train'):
self.phase = phase
self.rng = np.random.RandomState(0)
def _get_valid_names(self, *dirs, shuffle=True):
name_sets = [self._get_name_set(d) for d in dirs]
def _join_and(a, b):
return a & b
valid_names = list(functools.reduce(_join_and, name_sets))
if shuffle:
self.rng.shuffle(valid_names)
return valid_names
@staticmethod
def _get_name_set(dir_name):
path_list = glob.glob(os.path.join(dir_name, '*'))
name_set = set()
for path in path_list:
name = os.path.basename(path)
name = os.path.splitext(name)[0]
name_set.add(name)
return name_set
@staticmethod
def _list_abspath(data_dir, ext, data_list):
return [os.path.join(data_dir, name + ext)
for name in data_list]
class ImageFileTrain(ImageFile):
def __init__(
self,
alpha_dir="train_alpha",
fg_dir="train_fg",
bg_dir="train_bg",
alpha_ext=".jpg",
fg_ext=".jpg",
bg_ext=".jpg",
fg_have_bg_num=None,
alpha_ratio_json = None,
alpha_min_ratio = None,
key_sample_ratio = None,
):
super(ImageFileTrain, self).__init__(phase="train")
self.alpha_dir = alpha_dir
self.fg_dir = fg_dir
self.bg_dir = bg_dir
self.alpha_ext = alpha_ext
self.fg_ext = fg_ext
self.bg_ext = bg_ext
logger = setup_logger(name=__name__)
if not isinstance(self.alpha_dir, str):
assert len(self.alpha_dir) == len(self.fg_dir) == len(alpha_ext) == len(fg_ext)
self.valid_fg_list = []
self.alpha = []
self.fg = []
self.key_alpha = []
self.key_fg = []
for i in range(len(self.alpha_dir)):
valid_fg_list = self._get_valid_names(self.fg_dir[i], self.alpha_dir[i])
valid_fg_list.sort()
alpha = self._list_abspath(self.alpha_dir[i], self.alpha_ext[i], valid_fg_list)
fg = self._list_abspath(self.fg_dir[i], self.fg_ext[i], valid_fg_list)
self.valid_fg_list += valid_fg_list
self.alpha += alpha * fg_have_bg_num[i]
self.fg += fg * fg_have_bg_num[i]
if alpha_ratio_json[i] is not None:
tmp_key_alpha = []
tmp_key_fg = []
name_to_alpha_path = dict()
for name in alpha:
name_to_alpha_path[name.split('/')[-1].split('.')[0]] = name
name_to_fg_path = dict()
for name in fg:
name_to_fg_path[name.split('/')[-1].split('.')[0]] = name
with open(alpha_ratio_json[i], 'r') as file:
alpha_ratio_list = json.load(file)
for ratio, name in alpha_ratio_list:
if ratio < alpha_min_ratio[i]:
break
tmp_key_alpha.append(name_to_alpha_path[name.split('.')[0]])
tmp_key_fg.append(name_to_fg_path[name.split('.')[0]])
self.key_alpha.extend(tmp_key_alpha * fg_have_bg_num[i])
self.key_fg.extend(tmp_key_fg * fg_have_bg_num[i])
            if len(self.key_alpha) != 0 and key_sample_ratio > 0:
                # append extra copies of the key samples so that they make up roughly
                # key_sample_ratio of the final list: with K key samples already in a
                # list of N, solve K * (1 + n) / (N + K * n) = r for n
                repeat_num = key_sample_ratio * (len(self.alpha) - len(self.key_alpha)) / len(self.key_alpha) / (1 - key_sample_ratio) - 1
                print('key sample num:', len(self.key_alpha), ', repeat num:', repeat_num)
for i in range(math.ceil(repeat_num)):
self.alpha += self.key_alpha
self.fg += self.key_fg
else:
self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir)
self.valid_fg_list.sort()
self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list)
self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list)
self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)]
self.valid_bg_list.sort()
if fg_have_bg_num is not None:
# assert fg_have_bg_num * len(self.valid_fg_list) <= len(self.valid_bg_list)
# self.valid_bg_list = self.valid_bg_list[: fg_have_bg_num * len(self.valid_fg_list)]
assert len(self.alpha) <= len(self.valid_bg_list)
self.valid_bg_list = self.valid_bg_list[: len(self.alpha)]
self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list)
def __len__(self):
return len(self.alpha)
class ImageFileTest(ImageFile):
def __init__(self,
alpha_dir="test_alpha",
merged_dir="test_merged",
trimap_dir="test_trimap",
alpha_ext=".png",
merged_ext=".png",
trimap_ext=".png"):
super(ImageFileTest, self).__init__(phase="test")
self.alpha_dir = alpha_dir
self.merged_dir = merged_dir
self.trimap_dir = trimap_dir
self.alpha_ext = alpha_ext
self.merged_ext = merged_ext
self.trimap_ext = trimap_ext
self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False)
self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list)
self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list)
self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list)
def __len__(self):
return len(self.alpha)
interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
def maybe_random_interp(cv2_interp):
    # return a random interpolation mode when enabled; otherwise the requested one
    if CONFIG.data.random_interp:
        return np.random.choice(interp_list)
    else:
        return cv2_interp
class ToTensor(object):
"""
Convert ndarrays in sample to Tensors with normalization.
"""
def __init__(self, phase="test"):
self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
self.phase = phase
def __call__(self, sample):
image, alpha, trimap, mask = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'], sample['mask']
alpha[alpha < 0 ] = 0
alpha[alpha > 1] = 1
image = image.transpose((2, 0, 1)).astype(np.float32)
alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
mask = np.expand_dims(mask.astype(np.float32), axis=0)
image /= 255.
if self.phase == "train":
fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
sample['fg'] = torch.from_numpy(fg)
bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
sample['bg'] = torch.from_numpy(bg)
        sample['image'], sample['alpha'], sample['trimap'] = \
            torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
if CONFIG.model.trimap_channel == 3:
sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float()
elif CONFIG.model.trimap_channel == 1:
sample['trimap'] = sample['trimap'][None,...].float()
else:
raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1")
        # map trimap values 0/128/255 to 0/0.5/1; the order matters, since pixels
        # set to 1 by the second line fall below 85 and escape the third
        sample['trimap'][sample['trimap'] < 85] = 0
        sample['trimap'][sample['trimap'] >= 170] = 1
        sample['trimap'][sample['trimap'] >= 85] = 0.5
sample['mask'] = torch.from_numpy(mask).float()
return sample
class RandomAffine(object):
"""
Random affine translation
"""
def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
"degrees should be a list or tuple and it must be of length 2."
self.degrees = degrees
if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-shear, shear)
else:
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
self.resample = resample
self.fillcolor = fillcolor
self.flip = flip
@staticmethod
def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
"""Get parameters for affine transformation
Returns:
sequence: params to be passed to the affine transformation
"""
angle = random.uniform(degrees[0], degrees[1])
if translate is not None:
max_dx = translate[0] * img_size[0]
max_dy = translate[1] * img_size[1]
translations = (np.round(random.uniform(-max_dx, max_dx)),
np.round(random.uniform(-max_dy, max_dy)))
else:
translations = (0, 0)
if scale_ranges is not None:
scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
random.uniform(scale_ranges[0], scale_ranges[1]))
else:
scale = (1.0, 1.0)
if shears is not None:
shear = random.uniform(shears[0], shears[1])
else:
shear = 0.0
if flip is not None:
flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1
return angle, translations, scale, shear, flip
def __call__(self, sample):
fg, alpha = sample['fg'], sample['alpha']
rows, cols, ch = fg.shape
        if max(rows, cols) < 1024:
            # small images: disable rotation (degrees fixed to 0); img_size is
            # (width, height) -- fg is an ndarray, so it has no PIL-style .size
            params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, (cols, rows))
        else:
            params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, (cols, rows))
center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
M = self._get_inverse_affine_matrix(center, *params)
M = np.array(M).reshape((2, 3))
fg = cv2.warpAffine(fg, M, (cols, rows),
flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
alpha = cv2.warpAffine(alpha, M, (cols, rows),
flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
sample['fg'], sample['alpha'] = fg, alpha
return sample
    @staticmethod
def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
angle = math.radians(angle)
shear = math.radians(shear)
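        # For reference: with C the translation to the image center, the forward
        # transform is T * C * RSS * C^{-1}, where RSS = R(angle) * Shear(shear) * Scale(sx, sy)
        # (flip is folded into the scale signs); this method returns
        # C * RSS^{-1} * C^{-1} * T^{-1}, which cv2.warpAffine consumes together
        # with the WARP_INVERSE_MAP flag.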
scale_x = 1.0 / scale[0] * flip[0]
scale_y = 1.0 / scale[1] * flip[1]
# Inverted rotation matrix with scale and shear
d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
matrix = [
math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
-math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
]
matrix = [m / d for m in matrix]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += center[0]
matrix[5] += center[1]
return matrix
class RandomJitter(object):
"""
Random change the hue of the image
"""
def __call__(self, sample):
sample_ori = sample.copy()
fg, alpha = sample['fg'], sample['alpha']
# if alpha is all 0 skip
if np.all(alpha==0):
return sample_ori
# convert to HSV space, convert to float32 image to keep precision during space conversion.
fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
# Hue noise
hue_jitter = np.random.randint(-40, 40)
fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360)
# Saturation noise
sat_bar = fg[:, :, 1][alpha > 0].mean()
if np.isnan(sat_bar):
return sample_ori
sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
sat = fg[:, :, 1]
sat = np.abs(sat + sat_jitter)
sat[sat>1] = 2 - sat[sat>1]
fg[:, :, 1] = sat
# Value noise
val_bar = fg[:, :, 2][alpha > 0].mean()
if np.isnan(val_bar):
return sample_ori
val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
val = fg[:, :, 2]
val = np.abs(val + val_jitter)
val[val>1] = 2 - val[val>1]
fg[:, :, 2] = val
# convert back to BGR space
fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR)
sample['fg'] = fg*255
return sample
class RandomHorizontalFlip(object):
"""
Random flip image and label horizontally
"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, sample):
fg, alpha = sample['fg'], sample['alpha']
if np.random.uniform(0, 1) < self.prob:
fg = cv2.flip(fg, 1)
alpha = cv2.flip(alpha, 1)
sample['fg'], sample['alpha'] = fg, alpha
return sample
class RandomCrop(object):
"""
Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size'
:param output_size (tuple or int): Desired output size. If int, square crop
is made.
"""
def __init__(self, output_size=( CONFIG.data.crop_size, CONFIG.data.crop_size)):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
self.margin = output_size[0] // 2
self.logger = logging.getLogger("Logger")
def __call__(self, sample):
fg, alpha, trimap, mask, name = sample['fg'], sample['alpha'], sample['trimap'], sample['mask'], sample['image_name']
bg = sample['bg']
h, w = trimap.shape
bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
if w < self.output_size[0]+1 or h < self.output_size[1]+1:
ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
# self.logger.warning("Size of {} is {}.".format(name, (h, w)))
while h < self.output_size[0]+1 or w < self.output_size[1]+1:
fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
interpolation=maybe_random_interp(cv2.INTER_NEAREST))
trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
mask = cv2.resize(mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
h, w = trimap.shape
small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
self.margin//4:(w-self.margin)//4] == 128)))
unknown_num = len(unknown_list)
if len(unknown_list) < 10:
left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
else:
idx = np.random.randint(unknown_num)
left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
        if len(np.where(trimap == 128)[0]) == 0:
            self.logger.error("{} does not have enough unknown area for crop. Resized to target size. "
                              "left_top: {}".format(name, left_top))
fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC))
mask_crop = cv2.resize(mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop})
return sample
class OriginScale(object):
def __call__(self, sample):
h, w = sample["alpha_shape"]
if h % 32 == 0 and w % 32 == 0:
return sample
target_h = 32 * ((h - 1) // 32 + 1)
target_w = 32 * ((w - 1) // 32 + 1)
pad_h = target_h - h
pad_w = target_w - w
padded_image = np.pad(sample['image'], ((0,pad_h), (0, pad_w), (0,0)), mode="reflect")
padded_trimap = np.pad(sample['trimap'], ((0,pad_h), (0, pad_w)), mode="reflect")
padded_mask = np.pad(sample['mask'], ((0,pad_h), (0, pad_w)), mode="reflect")
sample['image'] = padded_image
sample['trimap'] = padded_trimap
sample['mask'] = padded_mask
return sample
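# Example: OriginScale reflect-pads a 600x800 input to 608x800 (608 is the next
# multiple of 32); inputs whose sides are already divisible by 32 pass through.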
class GenMask(object):
def __init__(self):
self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]
def __call__(self, sample):
alpha_ori = sample['alpha']
h, w = alpha_ori.shape
max_kernel_size = 30
alpha = cv2.resize(alpha_ori, (640,640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
### generate trimap
        # fg_mask is 1 where alpha is (numerically) 1, bg_mask is 1 where it is 0;
        # eroding both leaves a band that becomes the unknown region
        fg_width = np.random.randint(1, max_kernel_size)
        bg_width = np.random.randint(1, max_kernel_size)
        fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
        bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])
trimap = np.ones_like(alpha) * 128
trimap[fg_mask == 1] = 255
trimap[bg_mask == 1] = 0
trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
sample['trimap'] = trimap
### generate mask
low = 0.01
high = 1.0
thres = random.random() * (high - low) + low
seg_mask = (alpha >= thres).astype(np.int32).astype(np.uint8)
random_num = random.randint(0,3)
if random_num == 0:
seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
elif random_num == 1:
seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
elif random_num == 2:
seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
elif random_num == 3:
seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
seg_mask = cv2.resize(seg_mask, (w,h), interpolation=cv2.INTER_NEAREST)
sample['mask'] = seg_mask
return sample
class Composite(object):
def __call__(self, sample):
fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
alpha[alpha < 0 ] = 0
alpha[alpha > 1] = 1
fg[fg < 0 ] = 0
fg[fg > 255] = 255
bg[bg < 0 ] = 0
bg[bg > 255] = 255
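        # standard compositing equation: I = alpha * F + (1 - alpha) * B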
image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
sample['image'] = image
return sample
class CutMask(object):
def __init__(self, perturb_prob = 0):
self.perturb_prob = perturb_prob
def __call__(self, sample):
        # return unchanged with probability perturb_prob; otherwise paste one
        # random rectangle of the mask over another
        if np.random.rand() < self.perturb_prob:
            return sample
mask = sample['mask'] # H x W, trimap 0--255, segmask 0--1, alpha 0--1
h, w = mask.shape
perturb_size_h, perturb_size_w = random.randint(h // 4, h // 2), random.randint(w // 4, w // 2)
x = random.randint(0, h - perturb_size_h)
y = random.randint(0, w - perturb_size_w)
x1 = random.randint(0, h - perturb_size_h)
y1 = random.randint(0, w - perturb_size_w)
mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1+perturb_size_h, y1:y1+perturb_size_w].copy()
sample['mask'] = mask
return sample
class ScaleFg(object):
def __init__(self, min_scale_fg_scale=0.5, max_scale_fg_scale=1.0):
self.min_scale_fg_scale = min_scale_fg_scale
self.max_scale_fg_scale = max_scale_fg_scale
def __call__(self, sample):
scale_factor = np.random.uniform(low=self.min_scale_fg_scale, high=self.max_scale_fg_scale)
fg, alpha = sample['fg'], sample['alpha'] # np.array(): [H, W, 3] 0 ~ 255 , [H, W] 0.0 ~ 1.0
h, w = alpha.shape
scale_h, scale_w = int(h * scale_factor), int(w * scale_factor)
new_fg, new_alpha = np.zeros_like(fg), np.zeros_like(alpha)
fg = cv2.resize(fg, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR)
alpha = cv2.resize(alpha, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR)
        if scale_factor <= 1:
            # paste the downscaled fg/alpha at a random offset on an empty canvas
            offset_h, offset_w = np.random.randint(h - scale_h + 1), np.random.randint(w - scale_w + 1)
            new_fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :] = fg
            new_alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w] = alpha
        else:
            # crop an (h, w) window at a random offset out of the upscaled fg/alpha
            offset_h, offset_w = np.random.randint(scale_h - h + 1), np.random.randint(scale_w - w + 1)
            new_fg = fg[offset_h: offset_h + h, offset_w: offset_w + w, :]
            new_alpha = alpha[offset_h: offset_h + h, offset_w: offset_w + w]
sample['fg'], sample['alpha'] = new_fg, new_alpha
return sample
class GenBBox(object):
def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None):
self.bbox_offset_factor = bbox_offset_factor
self.random_crop_bbox = random_crop_bbox
self.train_or_test = train_or_test
self.dataset_type = dataset_type
self.random_auto_matting = random_auto_matting
def __call__(self, sample):
alpha = sample['alpha'] # [1, H, W] 0.0 ~ 1.0
indices = torch.nonzero(alpha[0], as_tuple=True)
if len(indices[0]) > 0:
min_x, min_y = torch.min(indices[1]), torch.min(indices[0])
max_x, max_y = torch.max(indices[1]), torch.max(indices[0])
if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox:
ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1])
sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0]
bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]])
elif self.bbox_offset_factor != 0:
bbox_w = max(1, max_x - min_x)
bbox_h = max(1, max_y - min_y)
offset_w = math.ceil(self.bbox_offset_factor * bbox_w)
offset_h = math.ceil(self.bbox_offset_factor * bbox_h)
min_x = max(0, min_x + np.random.randint(-offset_w, offset_w))
max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w))
min_y = max(0, min_y + np.random.randint(-offset_h, offset_h))
max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h))
bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
else:
bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting:
bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]])
else:
bbox = torch.zeros(1, 4)
sample['bbox'] = bbox.float()
return sample
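# GenBBox output format: sample['bbox'] is a float tensor of shape [1, 4] holding
# (x_min, y_min, x_max, y_max) in pixel coordinates, or all zeros when the alpha
# matte is empty.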
class DataGenerator(Dataset):
def __init__(
self,
data,
phase="train",
crop_size=512,
remove_multi_fg=False,
min_scale_fg_scale=None,
max_scale_fg_scale=None,
with_bbox = False,
bbox_offset_factor = None,
return_keys = None,
random_crop_bbox = None,
dataset_name = None,
random_auto_matting = None,
):
self.phase = phase
# self.crop_size = CONFIG.data.crop_size
self.crop_size = crop_size
self.remove_multi_fg = remove_multi_fg
self.with_bbox = with_bbox
self.bbox_offset_factor = bbox_offset_factor
self.alpha = data.alpha
self.return_keys = return_keys
self.random_crop_bbox = random_crop_bbox
self.dataset_name = dataset_name
self.random_auto_matting = random_auto_matting
if self.phase == "train":
self.fg = data.fg
self.bg = data.bg
self.merged = []
self.trimap = []
else:
self.fg = []
self.bg = []
self.merged = data.merged
self.trimap = data.trimap
train_trans = [
RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
GenMask(),
CutMask(perturb_prob=CONFIG.data.cutmask_prob),
RandomCrop((self.crop_size, self.crop_size)),
RandomJitter(),
Composite(),
ToTensor(phase="train")
]
if min_scale_fg_scale is not None:
train_trans.insert(0, ScaleFg(min_scale_fg_scale, max_scale_fg_scale))
if self.with_bbox:
train_trans.append(GenBBox(bbox_offset_factor=self.bbox_offset_factor, random_crop_bbox=self.random_crop_bbox, random_auto_matting=self.random_auto_matting))
test_trans = [ OriginScale(), ToTensor() ]
self.transform = {
'train':
transforms.Compose(train_trans),
'val':
transforms.Compose([
OriginScale(),
ToTensor()
]),
'test':
transforms.Compose(test_trans)
}[phase]
self.fg_num = len(self.fg)
def select_keys(self, sample):
new_sample = {}
for key, val in sample.items():
if key in self.return_keys:
new_sample[key] = val
return new_sample
def __getitem__(self, idx):
if self.phase == "train":
fg = cv2.imread(self.fg[idx % self.fg_num])
alpha = cv2.imread(self.alpha[idx % self.fg_num], 0).astype(np.float32)/255
bg = cv2.imread(self.bg[idx], 1)
if not self.remove_multi_fg:
fg, alpha, multi_fg = self._composite_fg(fg, alpha, idx)
else:
multi_fg = False
image_name = os.path.split(self.fg[idx % self.fg_num])[-1]
sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name, 'multi_fg': multi_fg}
else:
image = cv2.imread(self.merged[idx])
alpha = cv2.imread(self.alpha[idx], 0)/255.
trimap = cv2.imread(self.trimap[idx], 0)
mask = (trimap >= 170).astype(np.float32)
image_name = os.path.split(self.merged[idx])[-1]
sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape}
sample = self.transform(sample)
if self.return_keys is not None:
sample = self.select_keys(sample)
if self.dataset_name is not None:
sample['dataset_name'] = self.dataset_name
return sample
def _composite_fg(self, fg, alpha, idx):
multi_fg = False
if np.random.rand() < 0.5:
idx2 = np.random.randint(self.fg_num) + idx
fg2 = cv2.imread(self.fg[idx2 % self.fg_num])
alpha2 = cv2.imread(self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255.
h, w = alpha.shape
fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
if np.any(alpha_tmp < 1):
fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
# The overlap of two 50% transparency should be 25%
alpha = alpha_tmp
fg = fg.astype(np.uint8)
multi_fg = True
if np.random.rand() < 0.25:
# fg = cv2.resize(fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
# alpha = cv2.resize(alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
return fg, alpha, multi_fg
def __len__(self):
if self.phase == "train":
return len(self.bg)
else:
return len(self.alpha)
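# Usage sketch for the training pipeline (hypothetical paths and arguments; the
# directory layout depends on your copy of the dataset):
#
#   train_files = ImageFileTrain(alpha_dir='/data/AIM/alpha', fg_dir='/data/AIM/fg',
#                                bg_dir='/data/coco_bg', alpha_ext='.png',
#                                fg_ext='.png', bg_ext='.jpg')
#   train_set = DataGenerator(train_files, phase='train', crop_size=CONFIG.data.crop_size)
#   train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
#   train_loader = Prefetcher(train_loader)  # optional GPU prefetching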
class ResziePad(object):
    """Resize so the long side equals target_size, then zero-pad to a square."""
def __init__(self, target_size=1024):
self.target_size = target_size
def __call__(self, sample):
_, H, W = sample['image'].shape
scale = self.target_size * 1.0 / max(H, W)
new_H, new_W = H * scale, W * scale
new_W = int(new_W + 0.5)
new_H = int(new_H + 0.5)
choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() else {'image', 'alpha'}
for key in choice:
if key in {'image', 'trimap'}:
sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
else:
# sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='nearest')[0]
sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
padding = torch.zeros([sample[key].shape[0], self.target_size, self.target_size], dtype=sample[key].dtype, device=sample[key].device)
padding[:, : new_H, : new_W] = sample[key]
sample[key] = padding
return sample
class Cv2ResziePad(object):
    """cv2-based variant of ResziePad; also converts arrays to CHW float tensors."""
def __init__(self, target_size=1024):
self.target_size = target_size
def __call__(self, sample):
H, W, _ = sample['image'].shape
scale = self.target_size * 1.0 / max(H, W)
new_H, new_W = H * scale, W * scale
new_W = int(new_W + 0.5)
new_H = int(new_H + 0.5)
choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() and sample['trimap'] is not None else {'image', 'alpha'}
for key in choice:
sample[key] = cv2.resize(sample[key], (new_W, new_H), interpolation=cv2.INTER_LINEAR) # cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC
if key == 'image':
padding = np.zeros([self.target_size, self.target_size, sample[key].shape[-1]], dtype=sample[key].dtype)
padding[: new_H, : new_W, :] = sample[key]
sample[key] = padding
sample[key] = sample[key][:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) #/ 255.0
else:
padding = np.zeros([self.target_size, self.target_size], dtype=sample[key].dtype)
padding[: new_H, : new_W] = sample[key]
sample[key] = padding
sample[key] = sample[key][None].astype(np.float32)
sample[key] = torch.from_numpy(sample[key])
return sample
class AdobeCompositionTest(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(os.listdir(os.path.join(self.data_dir, 'merged')))
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):  # 1000 images in the Composition-1k test set
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(os.path.join(self.data_dir, 'alpha_copy', self.file_names[idx])).convert('L')
tris = Image.open(os.path.join(self.data_dir, 'trimaps', self.file_names[idx]))
imgs = Image.open(os.path.join(self.data_dir, 'merged', self.file_names[idx]))
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'Adobe'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'Adobe_' + self.file_names[idx]
sample = self.transform(sample)
sample['trimap'][sample['trimap'] < 85] = 0
sample['trimap'][sample['trimap'] >= 170] = 1
sample['trimap'][sample['trimap'] >= 85] = 0.5
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class SIMTest(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, '*', 'alpha', '*']))) # [: 10]
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(self.file_names[idx]).convert('L')
# tris = Image.open(self.file_names[idx].replace('alpha', 'trimap'))
imgs = Image.open(self.file_names[idx].replace('alpha', 'merged'))
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'SIM'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
# sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'SIM_{}_{}'.format(self.file_names[idx].split('/')[-3], self.file_names[idx].split('/')[-1])
sample = self.transform(sample)
# sample['trimap'][sample['trimap'] < 85] = 0
# sample['trimap'][sample['trimap'] >= 170] = 1
# sample['trimap'][sample['trimap'] >= 85] = 0.5
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class RW100Test(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'mask', '*'])))
self.name_to_idx = dict()
for idx, file_name in enumerate(self.file_names):
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0, train_or_test='test', dataset_type='RW100')
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(self.file_names[idx]).convert('L')
imgs = Image.open(self.file_names[idx].replace('mask', 'image')[:-6] + '.jpg')
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'RW100'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'RW100_' + self.file_names[idx].split('/')[-1]
sample = self.transform(sample)
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class AIM500Test(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original', '*'])))
self.name_to_idx = dict()
for idx, file_name in enumerate(self.file_names):
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L')
# tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L')
imgs = Image.open(self.file_names[idx])
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'AIM500'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
# sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'AIM500_' + self.file_names[idx].split('/')[-1]
sample = self.transform(sample)
# sample['trimap'][sample['trimap'] < 85] = 0
# sample['trimap'][sample['trimap'] >= 170] = 1
# sample['trimap'][sample['trimap'] >= 85] = 0.5
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class RWP636Test(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'image', '*'])))
self.name_to_idx = dict()
for idx, file_name in enumerate(self.file_names):
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(self.file_names[idx].replace('image', 'alpha').replace('jpg', 'png')).convert('L')
imgs = Image.open(self.file_names[idx])
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'RWP636'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'RWP636_' + self.file_names[idx].split('/')[-1]
sample = self.transform(sample)
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class AM2KTest(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'validation/original', '*'])))
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L')
# tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L')
imgs = Image.open(self.file_names[idx])
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'AM2K'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
# sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'AM2K_' + self.file_names[idx].split('/')[-1]
sample = self.transform(sample)
# sample['trimap'][sample['trimap'] < 85] = 0
# sample['trimap'][sample['trimap'] >= 170] = 1
# sample['trimap'][sample['trimap'] >= 85] = 0.5
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class P3M500Test(Dataset):
def __init__(self, data_dir, target_size=1024, multi_fg=None):
self.data_dir = data_dir
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original_image', '*'])))
self.name_to_idx = dict()
for idx, file_name in enumerate(self.file_names):
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
test_trans = [
ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.file_names)
def __getitem__(self, idx):
phas = Image.open(self.file_names[idx].replace('original_image', 'mask').replace('jpg', 'png')).convert('L')
# tris = Image.open(self.file_names[idx].replace('original_image', 'trimap').replace('jpg', 'png')).convert('L')
imgs = Image.open(self.file_names[idx])
sample = {
'ori_h_w': (imgs.size[1], imgs.size[0]),
'data_type': 'P3M500'
}
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0
# sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
sample['image_name'] = 'P3M500_' + self.file_names[idx].split('/')[-1]
sample = self.transform(sample)
# sample['trimap'][sample['trimap'] < 85] = 0
# sample['trimap'][sample['trimap'] >= 170] = 1
# sample['trimap'][sample['trimap'] >= 85] = 0.5
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
class MattingTest(Dataset):
def __init__(
self,
data_type,
data_dir,
image_sub_path,
alpha_sub_path,
trimpa_sub_path=None,
target_size=1024,
multi_fg=None,
):
self.data_type = data_type
self.data_dir = data_dir
self.image_paths = sorted(glob.glob(os.path.join(*[data_dir, image_sub_path])))
self.alpha_paths = sorted(glob.glob(os.path.join(*[data_dir, alpha_sub_path])))
self.trimpa_paths = sorted(glob.glob(os.path.join(*[data_dir, trimpa_sub_path]))) if trimpa_sub_path is not None else None
self.name_to_idx = dict()
for idx, file_name in enumerate(self.image_paths):
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
test_trans = [
Cv2ResziePad(target_size=target_size),
GenBBox(bbox_offset_factor=0)
]
self.transform = transforms.Compose(test_trans)
self.multi_fg = multi_fg
    def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img = cv2.imread(self.image_paths[idx])
sample = {
'image': img.astype(np.float32) / 255,
'alpha': cv2.imread(self.alpha_paths[idx], 0).astype(np.float32) / 255,
'trimap': cv2.imread(self.trimpa_paths[idx], 0) if self.trimpa_paths is not None else None,
'ori_h_w': (img.shape[0], img.shape[1]),
'data_type': self.data_type,
'image_name': self.data_type + '_' + self.image_paths[idx].split('/')[-1]
}
sample = self.transform(sample)
if self.trimpa_paths is not None:
sample['trimap'][sample['trimap'] < 85] = 0
sample['trimap'][sample['trimap'] >= 170] = 1
sample['trimap'][sample['trimap'] >= 85] = 0.5
else:
del sample['trimap']
if self.multi_fg is not None:
sample['multi_fg'] = torch.tensor(self.multi_fg)
return sample
def adobe_composition_collate_fn(batch):
new_batch = defaultdict(list)
for sub_batch in batch:
for key in sub_batch.keys():
new_batch[key].append(sub_batch[key])
for key in new_batch:
if isinstance(new_batch[key][0], torch.Tensor):
new_batch[key] = torch.stack(new_batch[key])
return dict(new_batch)
def build_d2_test_dataloader(
dataset,
mapper=None,
total_batch_size=None,
local_batch_size=None,
num_workers=0,
collate_fn=None
):
assert (total_batch_size is None) != (
local_batch_size is None
), "Either total_batch_size or local_batch_size must be specified"
world_size = comm.get_world_size()
if total_batch_size is not None:
assert (
total_batch_size > 0 and total_batch_size % world_size == 0
), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
total_batch_size, world_size
)
batch_size = total_batch_size // world_size
if local_batch_size is not None:
batch_size = local_batch_size
logger = logging.getLogger(__name__)
    if batch_size != 1:
        logger.warning(
            "Test batch size is forced to 1; that is the only mode the d2 test loader supports."
        )
return build_detection_test_loader(
dataset=dataset,
mapper=mapper,
sampler=None,
num_workers=num_workers,
collate_fn=collate_fn,
)
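# Usage sketch (hypothetical path): a one-image-per-batch loader for the
# Composition-1k test set.
#
#   test_set = AdobeCompositionTest(data_dir='/data/Composition-1k-testset', target_size=1024)
#   test_loader = build_d2_test_dataloader(test_set, local_batch_size=1, num_workers=2,
#                                          collate_fn=adobe_composition_collate_fn)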
class AdobeCompositionEvaluator(DatasetEvaluator):
def __init__(
self,
save_eval_results_step=-1,
output_dir=None,
eval_dataset_type=['Adobe'],
distributed=True,
eval_w_sam_hq_mask = False,
):
self.save_eval_results_step = save_eval_results_step
self.output_dir = output_dir
self.eval_index = 0
self.eval_dataset_type = eval_dataset_type
self.eval_w_sam_hq_mask = eval_w_sam_hq_mask
self._distributed = distributed
self._logger = logging.getLogger(__name__)
def reset(self):
self.eval_metric = dict()
for i in self.eval_dataset_type:
self.eval_metric[i + '_MSE'] = []
self.eval_metric[i + '_SAD'] = []
self.eval_metric[i + '_MAD'] = []
self.eval_metric[i + '_Grad'] = []
self.eval_metric[i + '_Conn'] = []
        if self.output_dir is not None:
            os.makedirs(self.output_dir, exist_ok=True)
def process(self, inputs, outputs):
"""
Args:
inputs: {'alpha', 'trimap', 'image', 'bbox', 'image_name'}
outputs: [1, 1, H, W] 0. ~ 1.
"""
# crop the black pad area
assert inputs['image'].shape[-1] == inputs['image'].shape[-2] == 1024 and len(inputs['ori_h_w']) == 1
inputs['ori_h_w'] = inputs['ori_h_w'][0]
before_pad_h, before_pad_w = int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][0] + 0.5), int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][1] + 0.5)
inputs['image'] = inputs['image'][:, :, :before_pad_h, :before_pad_w]
inputs['alpha'] = inputs['alpha'][:, :, :before_pad_h, :before_pad_w]
if self.eval_w_sam_hq_mask:
outputs, samhq_low_res_masks = outputs[0][:, :, :before_pad_h, :before_pad_w], outputs[1][:, :, :before_pad_h, :before_pad_w]
pred_alpha, label_alpha, samhq_low_res_masks = outputs.cpu().numpy(), inputs['alpha'].numpy(), (samhq_low_res_masks > 0).float().cpu()
else:
outputs = outputs[:, :, :before_pad_h, :before_pad_w]
pred_alpha, label_alpha = outputs.cpu().numpy(), inputs['alpha'].numpy()
# if 'trimap' in inputs.keys():
# inputs['trimap'] = inputs['trimap'][:, :, :before_pad_h, :before_pad_w]
# trimap = inputs['trimap'].numpy()
# assert np.max(trimap) <= 1 and len(np.unique(trimap)) <= 3
# sad_loss_unknown = compute_sad_loss(pred_alpha, label_alpha, trimap, area='unknown')
# mse_loss_unknown = compute_mse_loss(pred_alpha, label_alpha, trimap, area='unknown')
# self.eval_metric[inputs['data_type'][0] + '_unknown_mse (1e-3)'].append(mse_loss_unknown)
# self.eval_metric[inputs['data_type'][0] + '_unknown_sad (1e3)'].append(sad_loss_unknown)
        # calculate losses over the whole image: a detailmap of all 128s marks
        # every pixel as "unknown" for the metric functions
        assert np.max(pred_alpha) <= 1 and np.max(label_alpha) <= 1
        eval_pred = np.uint8(pred_alpha[0, 0] * 255.0 + 0.5) * 1.0  # quantize to 8-bit levels
        eval_gt = label_alpha[0, 0] * 255.0
        detailmap = np.zeros_like(eval_gt) + 128
mse_loss_ = compute_mse_loss(eval_pred, eval_gt, detailmap)
sad_loss_ = compute_sad_loss(eval_pred, eval_gt, detailmap)[0]
mad_loss_ = compute_mad_loss(eval_pred, eval_gt, detailmap)
grad_loss_ = compute_gradient_loss(eval_pred, eval_gt, detailmap)
conn_loss_ = compute_connectivity_error(eval_pred, eval_gt, detailmap)
self.eval_metric[inputs['data_type'][0] + '_MSE'].append(mse_loss_)
self.eval_metric[inputs['data_type'][0] + '_SAD'].append(sad_loss_)
self.eval_metric[inputs['data_type'][0] + '_MAD'].append(mad_loss_)
self.eval_metric[inputs['data_type'][0] + '_Grad'].append(grad_loss_)
self.eval_metric[inputs['data_type'][0] + '_Conn'].append(conn_loss_)
# vis results
if self.save_eval_results_step != -1 and self.eval_index % self.save_eval_results_step == 0:
if self.eval_w_sam_hq_mask:
self.save_vis_results(inputs, pred_alpha, samhq_low_res_masks)
else:
self.save_vis_results(inputs, pred_alpha)
self.eval_index += 1
def save_vis_results(self, inputs, pred_alpha, samhq_low_res_masks=None):
# image
image = inputs['image'][0].permute(1, 2, 0) * 255.0
l, u, r, d = int(inputs['bbox'][0, 0, 0].item()), int(inputs['bbox'][0, 0, 1].item()), int(inputs['bbox'][0, 0, 2].item()), int(inputs['bbox'][0, 0, 3].item())
red_line = torch.tensor([[255., 0., 0.]], device=image.device, dtype=image.dtype)
image[u: d, l, :] = red_line
image[u: d, r, :] = red_line
image[u, l: r, :] = red_line
image[d, l: r, :] = red_line
image = np.uint8(image.numpy())
# trimap, pred_alpha, label_alpha
save_results = [image]
choice = [inputs['trimap'], torch.from_numpy(pred_alpha), inputs['alpha']] if 'trimap' in inputs.keys() else [torch.from_numpy(pred_alpha), inputs['alpha']]
for val in choice:
            val = val[0].permute(1, 2, 0).repeat(1, 1, 3) * 255.0 + 0.5  # +0.5 before the uint8 cast rounds to nearest
val = np.uint8(val.numpy())
save_results.append(val)
if samhq_low_res_masks is not None:
save_results.append(np.uint8(samhq_low_res_masks[0].permute(1, 2, 0).repeat(1, 1, 3).numpy() * 255.0))
save_results = np.concatenate(save_results, axis=1)
save_name = os.path.join(self.output_dir, inputs['image_name'][0])
Image.fromarray(save_results).save(save_name.replace('.jpg', '.png'))
def evaluate(self):
if self._distributed:
comm.synchronize()
eval_metric = comm.gather(self.eval_metric, dst=0)
if not comm.is_main_process():
return {}
            merged_eval_metric = defaultdict(list)
            for sub_eval_metric in eval_metric:
                for key, val in sub_eval_metric.items():
                    merged_eval_metric[key] += val
            eval_metric = merged_eval_metric
else:
eval_metric = self.eval_metric
eval_results = {}
for key, val in eval_metric.items():
if len(val) != 0:
# if 'mse' in key:
# eval_results[key] = np.array(val).mean() * 1e3
# else:
# assert 'sad' in key
# eval_results[key] = np.array(val).mean() / 1e3
eval_results[key] = np.array(val).mean()
return eval_results
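# Usage sketch (hypothetical `model` and `test_loader`): accumulate metrics over
# a test set and report their means.
#
#   evaluator = AdobeCompositionEvaluator(output_dir='./eval_vis', distributed=False)
#   evaluator.reset()
#   for inputs in test_loader:
#       pred_alpha = model(inputs)             # [1, 1, H, W], values in [0, 1]
#       evaluator.process(inputs, pred_alpha)
#   print(evaluator.evaluate())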
if __name__ == '__main__':
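    # Minimal smoke test: run the synthetic training transforms on random arrays
    # (no dataset on disk required). Shapes and values are illustrative only.
    dummy = {
        'fg': np.random.randint(0, 255, (600, 800, 3)).astype(np.float32),
        'alpha': np.random.rand(600, 800).astype(np.float32),
        'bg': np.random.randint(0, 255, (600, 800, 3)).astype(np.float32),
        'image_name': 'dummy.png',
    }
    pipeline = transforms.Compose([
        RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
        GenMask(),
        CutMask(perturb_prob=CONFIG.data.cutmask_prob),
        RandomCrop((CONFIG.data.crop_size, CONFIG.data.crop_size)),
        RandomJitter(),
        Composite(),
        ToTensor(phase='train'),
    ])
    out = pipeline(dummy)
    print({k: tuple(v.shape) for k, v in out.items() if isinstance(v, torch.Tensor)})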