|
'''
Dataloader to process the Adobe Image Matting dataset.

Adapted from GCA-Matting: https://github.com/Yaoyi-Li/GCA-Matting/tree/master/dataloader
'''
|
import os
import glob
import logging
import os.path as osp
import functools
import numpy as np
import torch
import cv2
import math
import numbers
import random
import pickle
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import transforms
from easydict import EasyDict
from detectron2.utils.logger import setup_logger
from detectron2.utils import comm
from detectron2.data import build_detection_test_loader
import torchvision.transforms.functional

import json
from PIL import Image
from detectron2.evaluation.evaluator import DatasetEvaluator
from collections import defaultdict

from data.evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error
|
|
|
|
|
CONFIG = EasyDict({})

CONFIG.model = EasyDict({})
CONFIG.model.trimap_channel = 1

CONFIG.data = EasyDict({})
CONFIG.data.crop_size = 512
CONFIG.data.cutmask_prob = 0.25
CONFIG.data.augmentation = True
CONFIG.data.random_interp = True
|
|
|
class Prefetcher(): |
|
""" |
|
Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py |
|
""" |
|
    def __init__(self, loader):
        self.orig_loader = loader
        self.stream = torch.cuda.Stream()
        self.next_sample = None
        self.loader = None  # set when iteration starts in __iter__
|
|
|
    def preload(self):
        try:
            self.next_sample = next(self.loader)
        except StopIteration:
            self.next_sample = None
            return

        # Copy the upcoming sample to the GPU on a side stream so the transfer
        # overlaps with computation on the default stream.
        with torch.cuda.stream(self.stream):
            for key, value in self.next_sample.items():
                if isinstance(value, torch.Tensor):
                    self.next_sample[key] = value.cuda(non_blocking=True)
|
|
|
    def __next__(self):
        # Block until the in-flight prefetch has finished, then hand the sample
        # to the caller and start prefetching the next one.
        torch.cuda.current_stream().wait_stream(self.stream)
        sample = self.next_sample
        if sample is not None:
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    # Tie each tensor to the consumer stream so the prefetch
                    # stream's memory is not reused prematurely.
                    sample[key].record_stream(torch.cuda.current_stream())
            self.preload()
        else:
            raise StopIteration("Loader is exhausted; re-create the iterator for a new epoch, "
                                "e.g. `iterator = iter(prefetcher); sample = next(iterator)`.")
        return sample
|
|
|
def __iter__(self): |
|
self.loader = iter(self.orig_loader) |
|
self.preload() |
|
return self |
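

# Illustrative usage sketch (not invoked anywhere in this file): wrapping a
# DataLoader with the Prefetcher so host-to-GPU copies overlap with compute.
# `dataset`, `num_steps`, and `batch_size` are placeholders; a CUDA device is
# assumed because Prefetcher creates a torch.cuda.Stream.
def _example_prefetcher_loop(dataset, num_steps=100, batch_size=8):
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
    prefetcher = Prefetcher(loader)
    for step, sample in zip(range(num_steps), prefetcher):
        # Tensor values in `sample` are already on the GPU at this point.
        image, alpha = sample['image'], sample['alpha']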
|
|
|
|
|
class ImageFile(object): |
|
def __init__(self, phase='train'): |
|
self.phase = phase |
|
self.rng = np.random.RandomState(0) |
|
|
|
def _get_valid_names(self, *dirs, shuffle=True): |
|
name_sets = [self._get_name_set(d) for d in dirs] |
|
|
|
def _join_and(a, b): |
|
return a & b |
|
|
|
valid_names = list(functools.reduce(_join_and, name_sets)) |
|
if shuffle: |
|
self.rng.shuffle(valid_names) |
|
|
|
return valid_names |
|
|
|
@staticmethod |
|
def _get_name_set(dir_name): |
|
path_list = glob.glob(os.path.join(dir_name, '*')) |
|
name_set = set() |
|
for path in path_list: |
|
name = os.path.basename(path) |
|
name = os.path.splitext(name)[0] |
|
name_set.add(name) |
|
return name_set |
|
|
|
@staticmethod |
|
def _list_abspath(data_dir, ext, data_list): |
|
return [os.path.join(data_dir, name + ext) |
|
for name in data_list] |
|
|
|
class ImageFileTrain(ImageFile): |
|
def __init__( |
|
self, |
|
alpha_dir="train_alpha", |
|
fg_dir="train_fg", |
|
bg_dir="train_bg", |
|
alpha_ext=".jpg", |
|
fg_ext=".jpg", |
|
bg_ext=".jpg", |
|
fg_have_bg_num=None, |
|
alpha_ratio_json = None, |
|
alpha_min_ratio = None, |
|
key_sample_ratio = None, |
|
): |
|
super(ImageFileTrain, self).__init__(phase="train") |
|
|
|
self.alpha_dir = alpha_dir |
|
self.fg_dir = fg_dir |
|
self.bg_dir = bg_dir |
|
self.alpha_ext = alpha_ext |
|
self.fg_ext = fg_ext |
|
self.bg_ext = bg_ext |
|
logger = setup_logger(name=__name__) |
|
|
|
if not isinstance(self.alpha_dir, str): |
|
assert len(self.alpha_dir) == len(self.fg_dir) == len(alpha_ext) == len(fg_ext) |
|
self.valid_fg_list = [] |
|
self.alpha = [] |
|
self.fg = [] |
|
self.key_alpha = [] |
|
self.key_fg = [] |
|
for i in range(len(self.alpha_dir)): |
|
valid_fg_list = self._get_valid_names(self.fg_dir[i], self.alpha_dir[i]) |
|
valid_fg_list.sort() |
|
alpha = self._list_abspath(self.alpha_dir[i], self.alpha_ext[i], valid_fg_list) |
|
fg = self._list_abspath(self.fg_dir[i], self.fg_ext[i], valid_fg_list) |
|
self.valid_fg_list += valid_fg_list |
|
|
|
self.alpha += alpha * fg_have_bg_num[i] |
|
self.fg += fg * fg_have_bg_num[i] |
|
|
|
if alpha_ratio_json[i] is not None: |
|
tmp_key_alpha = [] |
|
tmp_key_fg = [] |
|
name_to_alpha_path = dict() |
|
for name in alpha: |
|
name_to_alpha_path[name.split('/')[-1].split('.')[0]] = name |
|
name_to_fg_path = dict() |
|
for name in fg: |
|
name_to_fg_path[name.split('/')[-1].split('.')[0]] = name |
|
|
|
with open(alpha_ratio_json[i], 'r') as file: |
|
alpha_ratio_list = json.load(file) |
|
for ratio, name in alpha_ratio_list: |
|
if ratio < alpha_min_ratio[i]: |
|
break |
|
tmp_key_alpha.append(name_to_alpha_path[name.split('.')[0]]) |
|
tmp_key_fg.append(name_to_fg_path[name.split('.')[0]]) |
|
|
|
self.key_alpha.extend(tmp_key_alpha * fg_have_bg_num[i]) |
|
self.key_fg.extend(tmp_key_fg * fg_have_bg_num[i]) |
|
|
|
            if len(self.key_alpha) != 0 and key_sample_ratio > 0:
                # Repeat the key samples until they make up `key_sample_ratio` of the
                # final list: solving k(1 + R) / (n + R k) = p for the number of extra
                # repetitions R (n = len(alpha), k = len(key_alpha), p = key_sample_ratio)
                # gives R = p (n - k) / (k (1 - p)) - 1.
                repeat_num = key_sample_ratio * (len(self.alpha) - len(self.key_alpha)) / len(self.key_alpha) / (1 - key_sample_ratio) - 1
                print('key sample num:', len(self.key_alpha), ', repeat num: ', repeat_num)
                for i in range(math.ceil(repeat_num)):
                    self.alpha += self.key_alpha
                    self.fg += self.key_fg
|
|
|
else: |
|
self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir) |
|
self.valid_fg_list.sort() |
|
self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list) |
|
self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list) |
|
|
|
self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)] |
|
self.valid_bg_list.sort() |
|
|
|
        if fg_have_bg_num is not None:
            # Every (repeated) foreground is paired with a distinct background,
            # so the background pool must be at least as large.
            assert len(self.alpha) <= len(self.valid_bg_list)
            self.valid_bg_list = self.valid_bg_list[: len(self.alpha)]
|
|
|
self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list) |
|
|
|
def __len__(self): |
|
return len(self.alpha) |
|
|
|
class ImageFileTest(ImageFile): |
|
def __init__(self, |
|
alpha_dir="test_alpha", |
|
merged_dir="test_merged", |
|
trimap_dir="test_trimap", |
|
alpha_ext=".png", |
|
merged_ext=".png", |
|
trimap_ext=".png"): |
|
super(ImageFileTest, self).__init__(phase="test") |
|
|
|
self.alpha_dir = alpha_dir |
|
self.merged_dir = merged_dir |
|
self.trimap_dir = trimap_dir |
|
self.alpha_ext = alpha_ext |
|
self.merged_ext = merged_ext |
|
self.trimap_ext = trimap_ext |
|
|
|
self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False) |
|
|
|
self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list) |
|
self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list) |
|
self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list) |
|
|
|
def __len__(self): |
|
return len(self.alpha) |
|
|
|
interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4] |
|
|
|
|
|
def maybe_random_interp(cv2_interp):
    # With random interpolation enabled, the suggested flag is ignored and one
    # of the four modes above is sampled instead.
    if CONFIG.data.random_interp:
        return np.random.choice(interp_list)
    else:
        return cv2_interp
|
|
|
|
|
class ToTensor(object):
    """
    Convert ndarrays in sample to Tensors and scale images to [0, 1].
    The ImageNet mean/std below are kept for reference but are not applied here.
    """
    def __init__(self, phase="test"):
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
        self.phase = phase
|
|
|
def __call__(self, sample): |
|
image, alpha, trimap, mask = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'], sample['mask'] |
|
|
|
alpha[alpha < 0 ] = 0 |
|
alpha[alpha > 1] = 1 |
|
|
|
image = image.transpose((2, 0, 1)).astype(np.float32) |
|
alpha = np.expand_dims(alpha.astype(np.float32), axis=0) |
|
|
|
mask = np.expand_dims(mask.astype(np.float32), axis=0) |
|
|
|
image /= 255. |
|
|
|
if self.phase == "train": |
|
fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255. |
|
sample['fg'] = torch.from_numpy(fg) |
|
bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255. |
|
sample['bg'] = torch.from_numpy(bg) |
|
|
|
        sample['image'], sample['alpha'], sample['trimap'] = \
            torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)

        if CONFIG.model.trimap_channel == 3:
            # This branch expects a trimap already label-encoded as {0, 1, 2}.
            sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float()
        elif CONFIG.model.trimap_channel == 1:
            sample['trimap'] = sample['trimap'][None,...].float()
        else:
            raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1")
        # Remap 8-bit trimap values to {0, 0.5, 1}. The assignment order matters:
        # values in [85, 170) survive the first two masks and become 0.5.
        sample['trimap'][sample['trimap'] < 85] = 0
        sample['trimap'][sample['trimap'] >= 170] = 1
        sample['trimap'][sample['trimap'] >= 85] = 0.5
|
|
|
sample['mask'] = torch.from_numpy(mask).float() |
|
|
|
return sample |
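

# Minimal sketch of ToTensor's trimap encoding on a toy sample (illustrative;
# shapes and values are arbitrary): an 8-bit trimap of 128s falls in [85, 170)
# and therefore maps to the unknown value 0.5.
def _example_totensor_trimap():
    h, w = 64, 64
    sample = {
        'image': np.zeros((h, w, 3), dtype=np.uint8),
        'alpha': np.zeros((h, w), dtype=np.float32),
        'trimap': np.full((h, w), 128, dtype=np.uint8),
        'mask': np.zeros((h, w), dtype=np.float32),
    }
    out = ToTensor(phase="test")(sample)
    assert out['trimap'].shape == (1, h, w) and float(out['trimap'].max()) == 0.5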
|
|
|
|
|
class RandomAffine(object): |
|
""" |
|
Random affine translation |
|
""" |
|
def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0): |
|
if isinstance(degrees, numbers.Number): |
|
if degrees < 0: |
|
raise ValueError("If degrees is a single number, it must be positive.") |
|
self.degrees = (-degrees, degrees) |
|
else: |
|
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ |
|
"degrees should be a list or tuple and it must be of length 2." |
|
self.degrees = degrees |
|
|
|
if translate is not None: |
|
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ |
|
"translate should be a list or tuple and it must be of length 2." |
|
for t in translate: |
|
if not (0.0 <= t <= 1.0): |
|
raise ValueError("translation values should be between 0 and 1") |
|
self.translate = translate |
|
|
|
if scale is not None: |
|
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ |
|
"scale should be a list or tuple and it must be of length 2." |
|
for s in scale: |
|
if s <= 0: |
|
raise ValueError("scale values should be positive") |
|
self.scale = scale |
|
|
|
if shear is not None: |
|
if isinstance(shear, numbers.Number): |
|
if shear < 0: |
|
raise ValueError("If shear is a single number, it must be positive.") |
|
self.shear = (-shear, shear) |
|
else: |
|
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ |
|
"shear should be a list or tuple and it must be of length 2." |
|
self.shear = shear |
|
else: |
|
self.shear = shear |
|
|
|
self.resample = resample |
|
self.fillcolor = fillcolor |
|
self.flip = flip |
|
|
|
@staticmethod |
|
def get_params(degrees, translate, scale_ranges, shears, flip, img_size): |
|
"""Get parameters for affine transformation |
|
|
|
Returns: |
|
sequence: params to be passed to the affine transformation |
|
""" |
|
angle = random.uniform(degrees[0], degrees[1]) |
|
if translate is not None: |
|
max_dx = translate[0] * img_size[0] |
|
max_dy = translate[1] * img_size[1] |
|
translations = (np.round(random.uniform(-max_dx, max_dx)), |
|
np.round(random.uniform(-max_dy, max_dy))) |
|
else: |
|
translations = (0, 0) |
|
|
|
if scale_ranges is not None: |
|
scale = (random.uniform(scale_ranges[0], scale_ranges[1]), |
|
random.uniform(scale_ranges[0], scale_ranges[1])) |
|
else: |
|
scale = (1.0, 1.0) |
|
|
|
if shears is not None: |
|
shear = random.uniform(shears[0], shears[1]) |
|
else: |
|
shear = 0.0 |
|
|
|
if flip is not None: |
|
flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1 |
|
|
|
return angle, translations, scale, shear, flip |
|
|
|
    def __call__(self, sample):
        fg, alpha = sample['fg'], sample['alpha']
        rows, cols, ch = fg.shape
        # get_params expects the image size as (width, height); note that
        # np.ndarray.size is a scalar, so the size is passed explicitly.
        # Rotation is disabled for images smaller than 1024 px on the longer side.
        if max(rows, cols) < 1024:
            params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, (cols, rows))
        else:
            params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, (cols, rows))
|
|
|
center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5) |
|
M = self._get_inverse_affine_matrix(center, *params) |
|
M = np.array(M).reshape((2, 3)) |
|
|
|
fg = cv2.warpAffine(fg, M, (cols, rows), |
|
flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP) |
|
alpha = cv2.warpAffine(alpha, M, (cols, rows), |
|
flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP) |
|
|
|
sample['fg'], sample['alpha'] = fg, alpha |
|
|
|
return sample |
|
|
|
|
|
    @staticmethod
    def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
        # Returns the inverse of the affine transform as a flattened 2x3 matrix,
        # which cv2.warpAffine consumes together with cv2.WARP_INVERSE_MAP.

        angle = math.radians(angle)
        shear = math.radians(shear)
        scale_x = 1.0 / scale[0] * flip[0]
        scale_y = 1.0 / scale[1] * flip[1]

        # Inverse of the rotation-with-shear block, with per-axis scale and flip.
        d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
        matrix = [
            math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
            -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
        ]
        matrix = [m / d for m in matrix]

        # Fold in the inverse of the center shift and the translation.
        matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
        matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])

        # Shift back to the image center.
        matrix[2] += center[0]
        matrix[5] += center[1]

        return matrix
|
|
|
|
|
class RandomJitter(object):
    """
    Randomly jitter the hue, saturation, and value of the foreground.
    """
|
|
|
def __call__(self, sample): |
|
sample_ori = sample.copy() |
|
fg, alpha = sample['fg'], sample['alpha'] |
|
|
|
if np.all(alpha==0): |
|
return sample_ori |
|
|
|
fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV) |
|
|
|
hue_jitter = np.random.randint(-40, 40) |
|
fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360) |
|
|
|
sat_bar = fg[:, :, 1][alpha > 0].mean() |
|
if np.isnan(sat_bar): |
|
return sample_ori |
|
sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10 |
|
sat = fg[:, :, 1] |
|
sat = np.abs(sat + sat_jitter) |
|
sat[sat>1] = 2 - sat[sat>1] |
|
fg[:, :, 1] = sat |
|
|
|
val_bar = fg[:, :, 2][alpha > 0].mean() |
|
if np.isnan(val_bar): |
|
return sample_ori |
|
val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10 |
|
val = fg[:, :, 2] |
|
val = np.abs(val + val_jitter) |
|
val[val>1] = 2 - val[val>1] |
|
fg[:, :, 2] = val |
|
|
|
fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR) |
|
sample['fg'] = fg*255 |
|
|
|
return sample |
|
|
|
|
|
class RandomHorizontalFlip(object): |
|
""" |
|
Random flip image and label horizontally |
|
""" |
|
def __init__(self, prob=0.5): |
|
self.prob = prob |
|
def __call__(self, sample): |
|
fg, alpha = sample['fg'], sample['alpha'] |
|
if np.random.uniform(0, 1) < self.prob: |
|
fg = cv2.flip(fg, 1) |
|
alpha = cv2.flip(alpha, 1) |
|
sample['fg'], sample['alpha'] = fg, alpha |
|
|
|
return sample |
|
|
|
|
|
class RandomCrop(object):
    """
    Randomly crop the sample around the trimap's unknown region; if the source
    is too small or has no unknown pixels, resize it to `output_size` instead.

    :param output_size (tuple or int): Desired output size. If int, a square
        crop is made.
    """
|
|
|
def __init__(self, output_size=( CONFIG.data.crop_size, CONFIG.data.crop_size)): |
|
assert isinstance(output_size, (int, tuple)) |
|
if isinstance(output_size, int): |
|
self.output_size = (output_size, output_size) |
|
else: |
|
assert len(output_size) == 2 |
|
self.output_size = output_size |
|
self.margin = output_size[0] // 2 |
|
self.logger = logging.getLogger("Logger") |
|
|
|
def __call__(self, sample): |
|
fg, alpha, trimap, mask, name = sample['fg'], sample['alpha'], sample['trimap'], sample['mask'], sample['image_name'] |
|
bg = sample['bg'] |
|
h, w = trimap.shape |
|
bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC)) |
|
if w < self.output_size[0]+1 or h < self.output_size[1]+1: |
|
ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w |
|
|
|
while h < self.output_size[0]+1 or w < self.output_size[1]+1: |
|
fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) |
|
alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)), |
|
interpolation=maybe_random_interp(cv2.INTER_NEAREST)) |
|
trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) |
|
bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC)) |
|
mask = cv2.resize(mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) |
|
h, w = trimap.shape |
|
small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST) |
|
unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4, |
|
self.margin//4:(w-self.margin)//4] == 128))) |
|
unknown_num = len(unknown_list) |
|
if len(unknown_list) < 10: |
|
left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1)) |
|
else: |
|
idx = np.random.randint(unknown_num) |
|
left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4) |
|
|
|
fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:] |
|
alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] |
|
bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:] |
|
trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] |
|
mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] |
|
|
|
        if len(np.where(trimap==128)[0]) == 0:
            self.logger.error("{} does not have enough unknown area for crop. Resized to target size. "
                              "left_top: {}".format(name, left_top))
|
fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST)) |
|
alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST)) |
|
trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) |
|
bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC)) |
|
mask_crop = cv2.resize(mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) |
|
|
|
sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop}) |
|
return sample |
|
|
|
|
|
class OriginScale(object):
    """Reflect-pad image/trimap/mask so both sides are multiples of 32."""
    def __call__(self, sample):
        h, w = sample["alpha_shape"]
|
|
|
if h % 32 == 0 and w % 32 == 0: |
|
return sample |
|
|
|
target_h = 32 * ((h - 1) // 32 + 1) |
|
target_w = 32 * ((w - 1) // 32 + 1) |
|
pad_h = target_h - h |
|
pad_w = target_w - w |
|
|
|
padded_image = np.pad(sample['image'], ((0,pad_h), (0, pad_w), (0,0)), mode="reflect") |
|
padded_trimap = np.pad(sample['trimap'], ((0,pad_h), (0, pad_w)), mode="reflect") |
|
padded_mask = np.pad(sample['mask'], ((0,pad_h), (0, pad_w)), mode="reflect") |
|
|
|
sample['image'] = padded_image |
|
sample['trimap'] = padded_trimap |
|
sample['mask'] = padded_mask |
|
|
|
return sample |
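

# Illustrative sketch of OriginScale's pad-to-multiple-of-32 behavior on a toy
# sample (values are placeholders): 500 rounds up to 512, while 800 is already
# a multiple of 32 and is left unchanged.
def _example_origin_scale():
    h, w = 500, 800
    sample = {
        'alpha_shape': (h, w),
        'image': np.zeros((h, w, 3), dtype=np.float32),
        'trimap': np.zeros((h, w), dtype=np.uint8),
        'mask': np.zeros((h, w), dtype=np.float32),
    }
    out = OriginScale()(sample)
    assert out['image'].shape == (512, 800, 3)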
|
|
|
|
|
class GenMask(object): |
|
def __init__(self): |
|
self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)] |
|
|
|
def __call__(self, sample): |
|
alpha_ori = sample['alpha'] |
|
h, w = alpha_ori.shape |
|
|
|
        max_kernel_size = 30
        alpha = cv2.resize(alpha_ori, (640,640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))

        # Build a trimap by eroding the confident foreground and background with
        # randomly sized kernels; whatever survives neither erosion becomes the
        # unknown band (128).
        fg_width = np.random.randint(1, max_kernel_size)
        bg_width = np.random.randint(1, max_kernel_size)
        fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
        bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])
|
|
|
trimap = np.ones_like(alpha) * 128 |
|
trimap[fg_mask == 1] = 255 |
|
trimap[bg_mask == 1] = 0 |
|
|
|
trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST) |
|
sample['trimap'] = trimap |
|
|
|
|
|
        # Build a coarse segmentation mask by thresholding alpha at a random
        # level, then randomly erode/dilate it to mimic imperfect guidance masks.
        low = 0.01
        high = 1.0
        thres = random.random() * (high - low) + low
        seg_mask = (alpha >= thres).astype(np.int32).astype(np.uint8)
        random_num = random.randint(0,3)
|
if random_num == 0: |
|
seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) |
|
elif random_num == 1: |
|
seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) |
|
elif random_num == 2: |
|
seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) |
|
seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) |
|
elif random_num == 3: |
|
seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) |
|
seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) |
|
|
|
seg_mask = cv2.resize(seg_mask, (w,h), interpolation=cv2.INTER_NEAREST) |
|
sample['mask'] = seg_mask |
|
|
|
return sample |
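

# Illustrative sketch: GenMask synthesizes a trimap and a coarse segmentation
# mask from alpha alone. The circular alpha matte is a placeholder input.
def _example_genmask():
    alpha = np.zeros((256, 256), dtype=np.float32)
    cv2.circle(alpha, (128, 128), 64, 1.0, -1)
    out = GenMask()({'alpha': alpha})
    # Trimap values: 0 (background), 128 (unknown band), 255 (foreground).
    assert set(np.unique(out['trimap'])).issubset({0, 128, 255})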
|
|
|
|
|
class Composite(object): |
|
def __call__(self, sample): |
|
fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha'] |
|
alpha[alpha < 0 ] = 0 |
|
alpha[alpha > 1] = 1 |
|
fg[fg < 0 ] = 0 |
|
fg[fg > 255] = 255 |
|
bg[bg < 0 ] = 0 |
|
bg[bg > 255] = 255 |
|
|
|
image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None]) |
|
sample['image'] = image |
|
return sample |
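

# Worked numeric example of the compositing equation I = alpha * F + (1 - alpha) * B
# on a tiny placeholder sample: a pixel with alpha 0.25 over a white foreground
# and black background comes out at 0.25 * 255 = 63.75.
def _example_composite():
    fg = np.full((2, 2, 3), 255.0, dtype=np.float32)
    bg = np.zeros((2, 2, 3), dtype=np.float32)
    alpha = np.array([[0.0, 0.25], [0.5, 1.0]], dtype=np.float32)
    out = Composite()({'fg': fg, 'bg': bg, 'alpha': alpha})
    assert np.allclose(out['image'][0, 1], 63.75)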
|
|
|
|
|
class CutMask(object): |
|
def __init__(self, perturb_prob = 0): |
|
self.perturb_prob = perturb_prob |
|
|
|
    def __call__(self, sample):
        # With probability `perturb_prob`, leave the mask untouched; otherwise
        # copy one random rectangle of the mask over another.
        if np.random.rand() < self.perturb_prob:
            return sample
|
|
|
mask = sample['mask'] |
|
h, w = mask.shape |
|
perturb_size_h, perturb_size_w = random.randint(h // 4, h // 2), random.randint(w // 4, w // 2) |
|
x = random.randint(0, h - perturb_size_h) |
|
y = random.randint(0, w - perturb_size_w) |
|
x1 = random.randint(0, h - perturb_size_h) |
|
y1 = random.randint(0, w - perturb_size_w) |
|
|
|
mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1+perturb_size_h, y1:y1+perturb_size_w].copy() |
|
|
|
sample['mask'] = mask |
|
return sample |
|
|
|
|
|
class ScaleFg(object): |
|
def __init__(self, min_scale_fg_scale=0.5, max_scale_fg_scale=1.0): |
|
self.min_scale_fg_scale = min_scale_fg_scale |
|
self.max_scale_fg_scale = max_scale_fg_scale |
|
|
|
def __call__(self, sample): |
|
scale_factor = np.random.uniform(low=self.min_scale_fg_scale, high=self.max_scale_fg_scale) |
|
|
|
fg, alpha = sample['fg'], sample['alpha'] |
|
h, w = alpha.shape |
|
scale_h, scale_w = int(h * scale_factor), int(w * scale_factor) |
|
|
|
new_fg, new_alpha = np.zeros_like(fg), np.zeros_like(alpha) |
|
fg = cv2.resize(fg, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR) |
|
alpha = cv2.resize(alpha, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
        if scale_factor <= 1:
            # Paste the shrunken fg/alpha at a random offset on an empty canvas
            # of the original size.
            offset_h, offset_w = np.random.randint(h - scale_h + 1), np.random.randint(w - scale_w + 1)
            new_fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :] = fg
            new_alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w] = alpha
        else:
            # Crop a window of the original size (h, w) from the enlarged fg/alpha.
            offset_h, offset_w = np.random.randint(scale_h - h + 1), np.random.randint(scale_w - w + 1)
            new_fg = fg[offset_h: offset_h + h, offset_w: offset_w + w, :]
            new_alpha = alpha[offset_h: offset_h + h, offset_w: offset_w + w]
|
|
|
sample['fg'], sample['alpha'] = new_fg, new_alpha |
|
return sample |
|
|
|
class GenBBox(object): |
|
def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None): |
|
self.bbox_offset_factor = bbox_offset_factor |
|
self.random_crop_bbox = random_crop_bbox |
|
self.train_or_test = train_or_test |
|
self.dataset_type = dataset_type |
|
self.random_auto_matting = random_auto_matting |
|
|
|
    def __call__(self, sample):
        # Compute a tight bounding box around the nonzero alpha region, then
        # optionally jitter it, crop to it, or replace it with the full image.
        alpha = sample['alpha']
        indices = torch.nonzero(alpha[0], as_tuple=True)

        if len(indices[0]) > 0:
            min_x, min_y = torch.min(indices[1]), torch.min(indices[0])
            max_x, max_y = torch.max(indices[1]), torch.max(indices[0])
|
|
|
if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox: |
|
ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1]) |
|
sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0] |
|
sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0] |
|
sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0] |
|
bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]]) |
|
|
|
elif self.bbox_offset_factor != 0: |
|
bbox_w = max(1, max_x - min_x) |
|
bbox_h = max(1, max_y - min_y) |
|
offset_w = math.ceil(self.bbox_offset_factor * bbox_w) |
|
offset_h = math.ceil(self.bbox_offset_factor * bbox_h) |
|
|
|
min_x = max(0, min_x + np.random.randint(-offset_w, offset_w)) |
|
max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w)) |
|
min_y = max(0, min_y + np.random.randint(-offset_h, offset_h)) |
|
max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h)) |
|
bbox = torch.tensor([[min_x, min_y, max_x, max_y]]) |
|
else: |
|
bbox = torch.tensor([[min_x, min_y, max_x, max_y]]) |
|
|
|
if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting: |
|
bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]]) |
|
|
|
else: |
|
bbox = torch.zeros(1, 4) |
|
|
|
sample['bbox'] = bbox.float() |
|
return sample |
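

# Illustrative sketch: with bbox_offset_factor=0 and no random options, GenBBox
# returns the tight box around the nonzero alpha region as [[min_x, min_y, max_x, max_y]].
def _example_genbbox():
    alpha = torch.zeros(1, 64, 64)
    alpha[:, 10:20, 30:40] = 1.0
    out = GenBBox(bbox_offset_factor=0)({'alpha': alpha})
    assert torch.equal(out['bbox'], torch.tensor([[30., 10., 39., 19.]]))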
|
|
|
class DataGenerator(Dataset): |
|
def __init__( |
|
self, |
|
data, |
|
phase="train", |
|
crop_size=512, |
|
remove_multi_fg=False, |
|
min_scale_fg_scale=None, |
|
max_scale_fg_scale=None, |
|
with_bbox = False, |
|
bbox_offset_factor = None, |
|
return_keys = None, |
|
random_crop_bbox = None, |
|
dataset_name = None, |
|
random_auto_matting = None, |
|
): |
|
self.phase = phase |
|
|
|
self.crop_size = crop_size |
|
self.remove_multi_fg = remove_multi_fg |
|
self.with_bbox = with_bbox |
|
self.bbox_offset_factor = bbox_offset_factor |
|
self.alpha = data.alpha |
|
self.return_keys = return_keys |
|
self.random_crop_bbox = random_crop_bbox |
|
self.dataset_name = dataset_name |
|
self.random_auto_matting = random_auto_matting |
|
|
|
if self.phase == "train": |
|
self.fg = data.fg |
|
self.bg = data.bg |
|
self.merged = [] |
|
self.trimap = [] |
|
else: |
|
self.fg = [] |
|
self.bg = [] |
|
self.merged = data.merged |
|
self.trimap = data.trimap |
|
|
|
train_trans = [ |
|
RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5), |
|
GenMask(), |
|
CutMask(perturb_prob=CONFIG.data.cutmask_prob), |
|
RandomCrop((self.crop_size, self.crop_size)), |
|
RandomJitter(), |
|
Composite(), |
|
ToTensor(phase="train") |
|
] |
|
if min_scale_fg_scale is not None: |
|
train_trans.insert(0, ScaleFg(min_scale_fg_scale, max_scale_fg_scale)) |
|
if self.with_bbox: |
|
train_trans.append(GenBBox(bbox_offset_factor=self.bbox_offset_factor, random_crop_bbox=self.random_crop_bbox, random_auto_matting=self.random_auto_matting)) |
|
|
|
test_trans = [ OriginScale(), ToTensor() ] |
|
|
|
self.transform = { |
|
'train': |
|
transforms.Compose(train_trans), |
|
'val': |
|
transforms.Compose([ |
|
OriginScale(), |
|
ToTensor() |
|
]), |
|
'test': |
|
transforms.Compose(test_trans) |
|
}[phase] |
|
|
|
self.fg_num = len(self.fg) |
|
|
|
def select_keys(self, sample): |
|
new_sample = {} |
|
for key, val in sample.items(): |
|
if key in self.return_keys: |
|
new_sample[key] = val |
|
return new_sample |
|
|
|
    def __getitem__(self, idx):
        if self.phase == "train":
            # Foregrounds are indexed modulo fg_num, so the background list may
            # be longer than the foreground list.
            fg = cv2.imread(self.fg[idx % self.fg_num])
            alpha = cv2.imread(self.alpha[idx % self.fg_num], 0).astype(np.float32)/255
            bg = cv2.imread(self.bg[idx], 1)
|
|
|
if not self.remove_multi_fg: |
|
fg, alpha, multi_fg = self._composite_fg(fg, alpha, idx) |
|
else: |
|
multi_fg = False |
|
image_name = os.path.split(self.fg[idx % self.fg_num])[-1] |
|
sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name, 'multi_fg': multi_fg} |
|
|
|
else: |
|
image = cv2.imread(self.merged[idx]) |
|
alpha = cv2.imread(self.alpha[idx], 0)/255. |
|
trimap = cv2.imread(self.trimap[idx], 0) |
|
mask = (trimap >= 170).astype(np.float32) |
|
image_name = os.path.split(self.merged[idx])[-1] |
|
|
|
sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape} |
|
|
|
sample = self.transform(sample) |
|
|
|
if self.return_keys is not None: |
|
sample = self.select_keys(sample) |
|
if self.dataset_name is not None: |
|
sample['dataset_name'] = self.dataset_name |
|
return sample |
|
|
|
    def _composite_fg(self, fg, alpha, idx):
        # With probability 0.5, composite a second foreground on top of the
        # current one; the merged alpha is 1 - (1 - alpha1)(1 - alpha2).
        multi_fg = False
        if np.random.rand() < 0.5:
            idx2 = np.random.randint(self.fg_num) + idx
            fg2 = cv2.imread(self.fg[idx2 % self.fg_num])
            alpha2 = cv2.imread(self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255.
            h, w = alpha.shape
            fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
            alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))

            alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
            if np.any(alpha_tmp < 1):
                fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
                alpha = alpha_tmp
                fg = fg.astype(np.uint8)
                multi_fg = True

        if np.random.rand() < 0.25:
            # Occasionally rescale to a fixed 1280 x 1280 before cropping.
            fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
            alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))

        return fg, alpha, multi_fg
|
|
|
def __len__(self): |
|
if self.phase == "train": |
|
return len(self.bg) |
|
else: |
|
return len(self.alpha) |
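

# Illustrative sketch (paths are placeholders): assembling the training pipeline
# from ImageFileTrain, DataGenerator, DataLoader, and Prefetcher. A CUDA device
# is assumed because Prefetcher creates a CUDA stream.
def _example_build_train_pipeline(alpha_dir, fg_dir, bg_dir, batch_size=4):
    image_file = ImageFileTrain(alpha_dir=alpha_dir, fg_dir=fg_dir, bg_dir=bg_dir,
                                alpha_ext=".png", fg_ext=".jpg", bg_ext=".jpg")
    dataset = DataGenerator(image_file, phase="train", crop_size=CONFIG.data.crop_size)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=4, pin_memory=True, drop_last=True)
    return Prefetcher(loader)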
|
|
|
|
|
class ResziePad(object):
    """
    Resize the longer side to `target_size` (preserving aspect ratio), then
    zero-pad bottom/right so every tensor becomes target_size x target_size.
    """

    def __init__(self, target_size=1024):
        self.target_size = target_size

    def __call__(self, sample):
        _, H, W = sample['image'].shape

        scale = self.target_size * 1.0 / max(H, W)
        new_H, new_W = H * scale, W * scale
        new_W = int(new_W + 0.5)
        new_H = int(new_H + 0.5)

        choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() else {'image', 'alpha'}
        for key in choice:
            sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
            # Pad every key so downstream code (e.g. the evaluator's 1024 x 1024
            # assertion) can rely on square inputs.
            padding = torch.zeros([sample[key].shape[0], self.target_size, self.target_size], dtype=sample[key].dtype, device=sample[key].device)
            padding[:, : new_H, : new_W] = sample[key]
            sample[key] = padding

        return sample
|
|
|
|
|
class Cv2ResziePad(object):
    """
    OpenCV counterpart of ResziePad: resize the longer side to `target_size`,
    zero-pad to a square, and convert to CHW float tensors (image BGR -> RGB).
    """

    def __init__(self, target_size=1024):
        self.target_size = target_size
|
|
|
def __call__(self, sample): |
|
H, W, _ = sample['image'].shape |
|
|
|
scale = self.target_size * 1.0 / max(H, W) |
|
new_H, new_W = H * scale, W * scale |
|
new_W = int(new_W + 0.5) |
|
new_H = int(new_H + 0.5) |
|
|
|
choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() and sample['trimap'] is not None else {'image', 'alpha'} |
|
for key in choice: |
|
sample[key] = cv2.resize(sample[key], (new_W, new_H), interpolation=cv2.INTER_LINEAR) |
|
|
|
if key == 'image': |
|
padding = np.zeros([self.target_size, self.target_size, sample[key].shape[-1]], dtype=sample[key].dtype) |
|
padding[: new_H, : new_W, :] = sample[key] |
|
sample[key] = padding |
|
sample[key] = sample[key][:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) |
|
else: |
|
padding = np.zeros([self.target_size, self.target_size], dtype=sample[key].dtype) |
|
padding[: new_H, : new_W] = sample[key] |
|
sample[key] = padding |
|
sample[key] = sample[key][None].astype(np.float32) |
|
sample[key] = torch.from_numpy(sample[key]) |
|
|
|
return sample |
|
|
|
|
|
class AdobeCompositionTest(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(os.listdir(os.path.join(self.data_dir, 'merged'))) |
|
|
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(os.path.join(self.data_dir, 'alpha_copy', self.file_names[idx])).convert('L') |
|
tris = Image.open(os.path.join(self.data_dir, 'trimaps', self.file_names[idx])) |
|
imgs = Image.open(os.path.join(self.data_dir, 'merged', self.file_names[idx])) |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'Adobe' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0 |
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'Adobe_' + self.file_names[idx] |
|
|
|
sample = self.transform(sample) |
|
sample['trimap'][sample['trimap'] < 85] = 0 |
|
sample['trimap'][sample['trimap'] >= 170] = 1 |
|
sample['trimap'][sample['trimap'] >= 85] = 0.5 |
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class SIMTest(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, '*', 'alpha', '*']))) |
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(self.file_names[idx]).convert('L') |
|
|
|
imgs = Image.open(self.file_names[idx].replace('alpha', 'merged')) |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'SIM' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
|
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'SIM_{}_{}'.format(self.file_names[idx].split('/')[-3], self.file_names[idx].split('/')[-1]) |
|
|
|
sample = self.transform(sample) |
|
|
|
|
|
|
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class RW100Test(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'mask', '*']))) |
|
|
|
self.name_to_idx = dict() |
|
for idx, file_name in enumerate(self.file_names): |
|
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx |
|
|
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0, train_or_test='test', dataset_type='RW100') |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(self.file_names[idx]).convert('L') |
|
imgs = Image.open(self.file_names[idx].replace('mask', 'image')[:-6] + '.jpg') |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'RW100' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'RW100_' + self.file_names[idx].split('/')[-1] |
|
|
|
sample = self.transform(sample) |
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class AIM500Test(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original', '*']))) |
|
|
|
self.name_to_idx = dict() |
|
for idx, file_name in enumerate(self.file_names): |
|
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx |
|
|
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L') |
|
|
|
imgs = Image.open(self.file_names[idx]) |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'AIM500' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
|
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'AIM500_' + self.file_names[idx].split('/')[-1] |
|
|
|
sample = self.transform(sample) |
|
|
|
|
|
|
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class RWP636Test(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'image', '*']))) |
|
|
|
self.name_to_idx = dict() |
|
for idx, file_name in enumerate(self.file_names): |
|
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx |
|
|
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(self.file_names[idx].replace('image', 'alpha').replace('jpg', 'png')).convert('L') |
|
imgs = Image.open(self.file_names[idx]) |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'RWP636' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'RWP636_' + self.file_names[idx].split('/')[-1] |
|
|
|
sample = self.transform(sample) |
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class AM2KTest(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'validation/original', '*']))) |
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L') |
|
|
|
imgs = Image.open(self.file_names[idx]) |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'AM2K' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
|
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'AM2K_' + self.file_names[idx].split('/')[-1] |
|
|
|
sample = self.transform(sample) |
|
|
|
|
|
|
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class P3M500Test(Dataset): |
|
def __init__(self, data_dir, target_size=1024, multi_fg=None): |
|
self.data_dir = data_dir |
|
self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original_image', '*']))) |
|
|
|
self.name_to_idx = dict() |
|
for idx, file_name in enumerate(self.file_names): |
|
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx |
|
|
|
test_trans = [ |
|
ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.file_names) |
|
|
|
def __getitem__(self, idx): |
|
phas = Image.open(self.file_names[idx].replace('original_image', 'mask').replace('jpg', 'png')).convert('L') |
|
|
|
imgs = Image.open(self.file_names[idx]) |
|
sample = { |
|
'ori_h_w': (imgs.size[1], imgs.size[0]), |
|
'data_type': 'P3M500' |
|
} |
|
|
|
sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) |
|
|
|
sample['image'] = torchvision.transforms.functional.to_tensor(imgs) |
|
sample['image_name'] = 'P3M500_' + self.file_names[idx].split('/')[-1] |
|
|
|
sample = self.transform(sample) |
|
|
|
|
|
|
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
class MattingTest(Dataset): |
|
def __init__( |
|
self, |
|
data_type, |
|
data_dir, |
|
image_sub_path, |
|
alpha_sub_path, |
|
trimpa_sub_path=None, |
|
target_size=1024, |
|
multi_fg=None, |
|
): |
|
self.data_type = data_type |
|
self.data_dir = data_dir |
|
|
|
self.image_paths = sorted(glob.glob(os.path.join(*[data_dir, image_sub_path]))) |
|
self.alpha_paths = sorted(glob.glob(os.path.join(*[data_dir, alpha_sub_path]))) |
|
self.trimpa_paths = sorted(glob.glob(os.path.join(*[data_dir, trimpa_sub_path]))) if trimpa_sub_path is not None else None |
|
|
|
self.name_to_idx = dict() |
|
for idx, file_name in enumerate(self.image_paths): |
|
self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx |
|
|
|
test_trans = [ |
|
Cv2ResziePad(target_size=target_size), |
|
GenBBox(bbox_offset_factor=0) |
|
] |
|
self.transform = transforms.Compose(test_trans) |
|
self.multi_fg = multi_fg |
|
|
|
def __len__(self): |
|
return len(self.image_paths) |
|
|
|
def __getitem__(self, idx): |
|
|
|
img = cv2.imread(self.image_paths[idx]) |
|
sample = { |
|
'image': img.astype(np.float32) / 255, |
|
'alpha': cv2.imread(self.alpha_paths[idx], 0).astype(np.float32) / 255, |
|
'trimap': cv2.imread(self.trimpa_paths[idx], 0) if self.trimpa_paths is not None else None, |
|
'ori_h_w': (img.shape[0], img.shape[1]), |
|
'data_type': self.data_type, |
|
'image_name': self.data_type + '_' + self.image_paths[idx].split('/')[-1] |
|
} |
|
|
|
sample = self.transform(sample) |
|
if self.trimpa_paths is not None: |
|
sample['trimap'][sample['trimap'] < 85] = 0 |
|
sample['trimap'][sample['trimap'] >= 170] = 1 |
|
sample['trimap'][sample['trimap'] >= 85] = 0.5 |
|
else: |
|
del sample['trimap'] |
|
|
|
if self.multi_fg is not None: |
|
sample['multi_fg'] = torch.tensor(self.multi_fg) |
|
|
|
return sample |
|
|
|
|
|
def adobe_composition_collate_fn(batch): |
|
new_batch = defaultdict(list) |
|
for sub_batch in batch: |
|
for key in sub_batch.keys(): |
|
new_batch[key].append(sub_batch[key]) |
|
for key in new_batch: |
|
if isinstance(new_batch[key][0], torch.Tensor): |
|
new_batch[key] = torch.stack(new_batch[key]) |
|
return dict(new_batch) |
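

# Illustrative sketch: the collate function stacks tensor fields and keeps
# non-tensor fields (e.g. 'image_name', 'data_type') as plain lists. `data_dir`
# is a placeholder path.
def _example_test_loader(data_dir):
    dataset = AdobeCompositionTest(data_dir, target_size=1024)
    return DataLoader(dataset, batch_size=1, num_workers=2,
                      collate_fn=adobe_composition_collate_fn)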
|
|
|
|
|
def build_d2_test_dataloader( |
|
dataset, |
|
mapper=None, |
|
total_batch_size=None, |
|
local_batch_size=None, |
|
num_workers=0, |
|
collate_fn=None |
|
): |
|
|
|
assert (total_batch_size is None) != ( |
|
local_batch_size is None |
|
), "Either total_batch_size or local_batch_size must be specified" |
|
|
|
world_size = comm.get_world_size() |
|
|
|
if total_batch_size is not None: |
|
assert ( |
|
total_batch_size > 0 and total_batch_size % world_size == 0 |
|
), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( |
|
total_batch_size, world_size |
|
) |
|
batch_size = total_batch_size // world_size |
|
|
|
if local_batch_size is not None: |
|
batch_size = local_batch_size |
|
|
|
    logger = logging.getLogger(__name__)
    if batch_size != 1:
        logger.warning(
            "Requested test batch size {} will be ignored: this detectron2 test "
            "loader is built with batch size 1.".format(batch_size)
        )
|
|
|
return build_detection_test_loader( |
|
dataset=dataset, |
|
mapper=mapper, |
|
sampler=None, |
|
num_workers=num_workers, |
|
collate_fn=collate_fn, |
|
) |
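

# Illustrative sketch of building the distributed detectron2 test loader;
# `data_dir` is a placeholder path.
def _example_d2_test_loader(data_dir):
    dataset = AdobeCompositionTest(data_dir, target_size=1024)
    return build_d2_test_dataloader(dataset,
                                    total_batch_size=comm.get_world_size(),
                                    num_workers=2,
                                    collate_fn=adobe_composition_collate_fn)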
|
|
|
|
|
class AdobeCompositionEvaluator(DatasetEvaluator): |
|
|
|
def __init__( |
|
self, |
|
save_eval_results_step=-1, |
|
output_dir=None, |
|
eval_dataset_type=['Adobe'], |
|
distributed=True, |
|
eval_w_sam_hq_mask = False, |
|
): |
|
|
|
self.save_eval_results_step = save_eval_results_step |
|
self.output_dir = output_dir |
|
self.eval_index = 0 |
|
self.eval_dataset_type = eval_dataset_type |
|
self.eval_w_sam_hq_mask = eval_w_sam_hq_mask |
|
|
|
self._distributed = distributed |
|
self._logger = logging.getLogger(__name__) |
|
|
|
def reset(self): |
|
self.eval_metric = dict() |
|
for i in self.eval_dataset_type: |
|
self.eval_metric[i + '_MSE'] = [] |
|
self.eval_metric[i + '_SAD'] = [] |
|
self.eval_metric[i + '_MAD'] = [] |
|
self.eval_metric[i + '_Grad'] = [] |
|
self.eval_metric[i + '_Conn'] = [] |
|
|
|
        if self.output_dir is not None:
            os.makedirs(self.output_dir, exist_ok=True)
|
|
|
def process(self, inputs, outputs): |
|
""" |
|
Args: |
|
inputs: {'alpha', 'trimap', 'image', 'bbox', 'image_name'} |
|
outputs: [1, 1, H, W] 0. ~ 1. |
|
""" |
|
|
|
|
|
        # Inputs are square 1024 x 1024 tensors with bottom/right zero padding;
        # crop back to the pre-padding resolution before computing metrics.
        assert inputs['image'].shape[-1] == inputs['image'].shape[-2] == 1024 and len(inputs['ori_h_w']) == 1
        inputs['ori_h_w'] = inputs['ori_h_w'][0]
        before_pad_h, before_pad_w = int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][0] + 0.5), int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][1] + 0.5)
        inputs['image'] = inputs['image'][:, :, :before_pad_h, :before_pad_w]
        inputs['alpha'] = inputs['alpha'][:, :, :before_pad_h, :before_pad_w]
|
|
|
if self.eval_w_sam_hq_mask: |
|
outputs, samhq_low_res_masks = outputs[0][:, :, :before_pad_h, :before_pad_w], outputs[1][:, :, :before_pad_h, :before_pad_w] |
|
pred_alpha, label_alpha, samhq_low_res_masks = outputs.cpu().numpy(), inputs['alpha'].numpy(), (samhq_low_res_masks > 0).float().cpu() |
|
        else:
            outputs = outputs[:, :, :before_pad_h, :before_pad_w]
            pred_alpha, label_alpha = outputs.cpu().numpy(), inputs['alpha'].numpy()

        assert np.max(pred_alpha) <= 1 and np.max(label_alpha) <= 1
        # Quantize the prediction to 8-bit levels and score against the ground
        # truth on a detail map of all 128s, i.e. every pixel counts as unknown.
        eval_pred = np.uint8(pred_alpha[0, 0] * 255.0 + 0.5) * 1.0
        eval_gt = label_alpha[0, 0] * 255.0

        detailmap = np.zeros_like(eval_gt) + 128
        mse_loss_ = compute_mse_loss(eval_pred, eval_gt, detailmap)
        sad_loss_ = compute_sad_loss(eval_pred, eval_gt, detailmap)[0]
        mad_loss_ = compute_mad_loss(eval_pred, eval_gt, detailmap)
        grad_loss_ = compute_gradient_loss(eval_pred, eval_gt, detailmap)
        conn_loss_ = compute_connectivity_error(eval_pred, eval_gt, detailmap)
|
|
|
self.eval_metric[inputs['data_type'][0] + '_MSE'].append(mse_loss_) |
|
self.eval_metric[inputs['data_type'][0] + '_SAD'].append(sad_loss_) |
|
self.eval_metric[inputs['data_type'][0] + '_MAD'].append(mad_loss_) |
|
self.eval_metric[inputs['data_type'][0] + '_Grad'].append(grad_loss_) |
|
self.eval_metric[inputs['data_type'][0] + '_Conn'].append(conn_loss_) |
|
|
|
|
|
if self.save_eval_results_step != -1 and self.eval_index % self.save_eval_results_step == 0: |
|
if self.eval_w_sam_hq_mask: |
|
self.save_vis_results(inputs, pred_alpha, samhq_low_res_masks) |
|
else: |
|
self.save_vis_results(inputs, pred_alpha) |
|
self.eval_index += 1 |
|
|
|
    def save_vis_results(self, inputs, pred_alpha, samhq_low_res_masks=None):

        # Draw the prompt bbox in red on the input image.
        image = inputs['image'][0].permute(1, 2, 0) * 255.0
        l, u, r, d = int(inputs['bbox'][0, 0, 0].item()), int(inputs['bbox'][0, 0, 1].item()), int(inputs['bbox'][0, 0, 2].item()), int(inputs['bbox'][0, 0, 3].item())
        red_line = torch.tensor([[255., 0., 0.]], device=image.device, dtype=image.dtype)
        image[u: d, l, :] = red_line
        image[u: d, r, :] = red_line
        image[u, l: r, :] = red_line
        image[d, l: r, :] = red_line
        image = np.uint8(image.numpy())
|
|
|
|
|
save_results = [image] |
|
|
|
choice = [inputs['trimap'], torch.from_numpy(pred_alpha), inputs['alpha']] if 'trimap' in inputs.keys() else [torch.from_numpy(pred_alpha), inputs['alpha']] |
|
for val in choice: |
|
val = val[0].permute(1, 2, 0).repeat(1, 1, 3) * 255.0 + 0.5 |
|
val = np.uint8(val.numpy()) |
|
save_results.append(val) |
|
|
|
if samhq_low_res_masks is not None: |
|
save_results.append(np.uint8(samhq_low_res_masks[0].permute(1, 2, 0).repeat(1, 1, 3).numpy() * 255.0)) |
|
|
|
save_results = np.concatenate(save_results, axis=1) |
|
save_name = os.path.join(self.output_dir, inputs['image_name'][0]) |
|
Image.fromarray(save_results).save(save_name.replace('.jpg', '.png')) |
|
|
|
def evaluate(self): |
|
|
|
if self._distributed: |
|
comm.synchronize() |
|
eval_metric = comm.gather(self.eval_metric, dst=0) |
|
|
|
if not comm.is_main_process(): |
|
return {} |
|
|
|
merges_eval_metric = defaultdict(list) |
|
for sub_eval_metric in eval_metric: |
|
for key, val in sub_eval_metric.items(): |
|
merges_eval_metric[key] += val |
|
eval_metric = merges_eval_metric |
|
|
|
else: |
|
eval_metric = self.eval_metric |
|
|
|
        eval_results = {}

        for key, val in eval_metric.items():
            if len(val) != 0:
                # Average each metric over all processed samples.
                eval_results[key] = np.array(val).mean()
|
|
|
return eval_results |
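

# Illustrative sketch (single process): running the evaluator over a test
# loader. `model` is a placeholder callable mapping a batched sample dict to a
# [1, 1, H, W] alpha prediction in [0, 1]; `data_dir` is a placeholder path.
def _example_evaluate(data_dir, model, output_dir=None):
    loader = DataLoader(AdobeCompositionTest(data_dir, target_size=1024), batch_size=1,
                        collate_fn=adobe_composition_collate_fn)
    evaluator = AdobeCompositionEvaluator(output_dir=output_dir, distributed=False)
    evaluator.reset()
    with torch.no_grad():
        for inputs in loader:
            evaluator.process(inputs, model(inputs))
    return evaluator.evaluate()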
|
|
|
|
|
if __name__ == '__main__': |
|
pass |