# FashionGen / netdissect / broden.py
# Committed by safi842 in 29cdbe6: "Files for the app"
import os, errno, numpy, torch, csv, re, shutil, os, zipfile
from collections import OrderedDict
from torchvision.datasets.folder import default_loader
from torchvision import transforms
from scipy import ndimage
from urllib.request import urlopen
class BrodenDataset(torch.utils.data.Dataset):
    '''
    A multicategory segmentation data set.

    Each item is a tuple of up to three streams:
    (1) The image, as returned by the loader and the optional ``transform``
        (e.g. a (3, h, w) tensor when ``transform`` ends in ToTensor).
    (2) The multicategory segmentation: an integer array of shape
        (max_segment_depth, sh, sw) holding one label-code map per
        annotation layer; layers beyond those present stay 0.
        ``transform_segment`` is applied to it when given.
    (3) Only when ``include_bincount`` is true: a per-label pixel count
        summed over all layers, of length ``num_labels``, with the count
        for label 0 zeroed.

    Net dissect also assumes that the dataset object has three properties
    with human-readable labels:
    ds.labels = ['red', 'black', 'car', 'tree', 'grid', ...]
    ds.categories = ['color', 'part', 'object', 'texture']
    ds.label_category = [0, 0, 2, 2, 3, ...] # The category for each label
    '''
    def __init__(self, directory='dataset/broden', resolution=384,
            split='train', categories=None,
            transform=None, transform_segment=None,
            download=False, size=None, include_bincount=True,
            broden_version=1, max_segment_depth=6):
        '''
        Read the Broden index and label tables from
        ``directory/broden<broden_version>_<resolution>``.

        directory: root folder containing the broden<v>_<res> folder.
        resolution: release resolution; must be 224, 227 or 384.
        split: keep only index rows whose 'split' column equals this.
        categories: optional iterable of category names to keep; names not
            in category.csv are ignored. None keeps every category.
        transform / transform_segment: optional callables applied in
            __getitem__ to the image and to the segmentation array.
        download: if true, fetch and unpack the dataset first.
        size: if given, keep only the first ``size`` matching images.
        include_bincount: whether __getitem__ also returns label counts.
        broden_version: version number used in the directory name.
        max_segment_depth: number of segmentation layers allocated per item.
        '''
        assert resolution in [224, 227, 384]
        if download:
            ensure_broden_downloaded(directory, resolution, broden_version)
        self.directory = directory
        self.resolution = resolution
        self.resdir = os.path.join(directory, 'broden%d_%d' %
                (broden_version, resolution))
        self.loader = default_loader
        self.transform = transform
        self.transform_segment = transform_segment
        self.include_bincount = include_bincount
        # The maximum number of multilabel layers that coexist at an image.
        self.max_segment_depth = max_segment_depth
        with open(os.path.join(self.resdir, 'category.csv'),
                encoding='utf-8') as f:
            self.category_info = OrderedDict()
            for row in csv.DictReader(f):
                self.category_info[row['name']] = row
        if categories is not None:
            # Keep only the requested categories that actually exist.
            wanted = {c for c in categories if c in self.category_info}
            for cat in list(self.category_info.keys()):
                if cat not in wanted:
                    del self.category_info[cat]
        # Category order follows category.csv, not the caller's order.
        categories = list(self.category_info.keys())
        self.categories = categories
        # Filter out unneeded images.
        with open(os.path.join(self.resdir, 'index.csv'),
                encoding='utf-8') as f:
            all_images = [decode_index_dict(r) for r in csv.DictReader(f)]
        self.image = [row for row in all_images
                if index_has_any_data(row, categories)
                and row['split'] == split]
        if size is not None:
            self.image = self.image[:size]
        with open(os.path.join(self.resdir, 'label.csv'),
                encoding='utf-8') as f:
            self.label_info = build_dense_label_array([
                decode_label_dict(r) for r in csv.DictReader(f)])
        self.labels = [l['name'] for l in self.label_info]
        # Build dense remapping arrays for labels, so that you can
        # get dense ranges of labels for each category.
        self.category_map = {}
        self.category_unmap = {}
        self.category_label = {}
        for cat in self.categories:
            with open(os.path.join(self.resdir, 'c_%s.csv' % cat),
                    encoding='utf-8') as f:
                c_data = [decode_label_dict(r) for r in csv.DictReader(f)]
            # unmap: per-category 'code' -> global 'number';
            # map: global 'number' -> per-category 'code' (0 if absent).
            self.category_unmap[cat], self.category_map[cat] = (
                    build_numpy_category_map(c_data))
            self.category_label[cat] = build_dense_label_array(
                    c_data, key='code')
        self.num_labels = len(self.labels)

        def coverage(cat, label):
            # Coverage of global `label` within `cat`, or 0 if it is absent.
            cmap = self.category_map[cat]
            if label < len(cmap) and cmap[label]:
                return self.category_label[cat][cmap[label]]['coverage']
            return 0

        # Primary categories for each label is the category in which it
        # appears with the maximum coverage (ties go to the later category).
        self.label_category = numpy.zeros(self.num_labels, dtype=int)
        for i in range(self.num_labels):
            # default= keeps an empty category list from raising ValueError.
            _, self.label_category[i] = max(
                    ((coverage(cat, i), ic)
                        for ic, cat in enumerate(categories)),
                    default=(0, 0))

    def __len__(self):
        '''Number of images kept after split/category/size filtering.'''
        return len(self.image)

    def __getitem__(self, idx):
        '''
        Return (image, segment, bincount), or (image, segment) when
        include_bincount is false, for the idx-th kept index record.

        Raises IndexError if the record has more annotation layers than
        max_segment_depth.
        '''
        record = self.image[idx]
        # example record: {
        #    'image': 'opensurfaces/25605.jpg', 'split': 'train',
        #    'ih': 384, 'iw': 384, 'sh': 192, 'sw': 192,
        #    'color': ['opensurfaces/25605_color.png'],
        #    'object': [], 'part': [],
        #    'material': ['opensurfaces/25605_material.png'],
        #    'scene': [], 'texture': []}
        image = self.loader(os.path.join(self.resdir, 'images',
            record['image']))
        segment = numpy.zeros(shape=(self.max_segment_depth,
            record['sh'], record['sw']), dtype=int)
        if self.include_bincount:
            bincount = numpy.zeros(shape=(self.num_labels,), dtype=int)
        depth = 0
        for cat in self.categories:
            for layer in record[cat]:
                if depth >= self.max_segment_depth:
                    raise IndexError(
                        'Image %s has more than max_segment_depth=%d '
                        'segmentation layers' % (record['image'],
                                                 self.max_segment_depth))
                if isinstance(layer, int):
                    # A bare integer labels every pixel of this layer.
                    segment[depth,:,:] = layer
                    if self.include_bincount:
                        bincount[layer] += segment.shape[1] * segment.shape[2]
                else:
                    png = numpy.asarray(self.loader(os.path.join(
                        self.resdir, 'images', layer)))
                    # The label code is channel 0 + 256 * channel 1. Widen to
                    # int first: uint8 * 256 raises OverflowError under
                    # NumPy 2's promotion rules (NEP 50).
                    segment[depth,:,:] = (png[:,:,0].astype(int)
                                          + png[:,:,1].astype(int) * 256)
                    if self.include_bincount:
                        bincount += numpy.bincount(
                            segment[depth,:,:].flatten(),
                            minlength=self.num_labels)
                depth += 1
        if self.transform:
            image = self.transform(image)
        if self.transform_segment:
            segment = self.transform_segment(segment)
        if self.include_bincount:
            bincount[0] = 0
            return (image, segment, bincount)
        else:
            return (image, segment)
def build_dense_label_array(label_data, key='number', allow_none=False):
    '''
    Index a list of rows by one of their integer fields.

    Input: list of row dicts, each with an integer field named ``key``
    (default 'number').
    Output: list ``a`` of length max(row[key]) + 1 such that
    ``a[row[key]] is row`` for every input row. Indices with no row are
    None when ``allow_none`` is true; otherwise they are filled with an
    "empty" row shaped like the first input row: every field holds the
    default of its type (0, '', [], {}, ...), except ``key``, which holds
    the index. An empty input yields an empty list.
    '''
    if not label_data:
        return []
    result = [None] * (max(d[key] for d in label_data) + 1)
    for d in label_data:
        result[d[key]] = d
    # Fill in none
    if not allow_none:
        example = label_data[0]
        for i, d in enumerate(result):
            if d is None:
                # Field names must be compared with ==, not `is`: names read
                # by csv.DictReader are not interned, so identity fails.
                result[i] = {c: (i if c == key else type(v)())
                             for c, v in example.items()}
    return result
def build_numpy_category_map(map_data, key1='code', key2='number',
        dtype=numpy.int16):
    '''
    Build a pair of inverse lookup arrays between two integer fields.

    Input: list of row dicts, each with integer fields ``key1`` and
    ``key2`` (default 'code' and 'number').
    Output: list ``[forward, backward]`` of numpy arrays of ``dtype``
    (default int16) such that, for every row,
    ``forward[row[key1]] == row[key2]`` and
    ``backward[row[key2]] == row[key1]``. Each array has length
    max(index field) + 1; slots not set by any row hold 0.
    '''
    forward = numpy.zeros(max(d[key1] for d in map_data) + 1, dtype=dtype)
    backward = numpy.zeros(max(d[key2] for d in map_data) + 1, dtype=dtype)
    for d in map_data:
        forward[d[key1]] = d[key2]
        backward[d[key2]] = d[key1]
    return [forward, backward]
def index_has_any_data(row, categories):
    '''
    Tell whether ``row`` carries at least one truthy annotation entry
    under any of the given ``categories`` (checked in order, stopping at
    the first hit). Returns True or False.
    '''
    return any(entry for cat in categories for entry in row[cat])
def decode_label_dict(row):
    '''
    Convert one CSV row (dict of strings) from a Broden label table into
    typed values.

    - 'category': ';'-separated entries of the form ``name(n)`` become a
      dict {name: int(n)}; empty entries are skipped.
    - 'name': kept as a string.
    - 'syns': split on ';' into a list of strings.
    - any other field: int if it is all digits, float if it looks like
      ``digits.digits``, otherwise left as a string.

    Raises ValueError if a 'category' entry is not of the form name(n).
    '''
    result = {}
    for key, val in row.items():
        if key == 'category':
            counts = {}
            for field in val.split(';'):
                if not field:
                    continue
                m = re.match(r'^([^(]*)\(([^)]*)\)$', field)
                if m is None:
                    raise ValueError('Malformed category entry %r' % field)
                counts[m.group(1)] = int(m.group(2))
            result[key] = counts
        elif key == 'name':
            result[key] = val
        elif key == 'syns':
            result[key] = val.split(';')
        elif re.match(r'^\d+$', val):
            result[key] = int(val)
        elif re.match(r'^\d+\.\d*$', val):
            result[key] = float(val)
        else:
            result[key] = val
    return result
def decode_index_dict(row):
    '''
    Convert one row (dict of strings) of Broden's index.csv into typed
    values.

    - 'image' and 'split' stay strings.
    - 'sw', 'sh', 'iw', 'ih' become ints.
    - every other column is split on ';' into a list with empty items
      dropped; all-digit items become ints, the rest stay strings.
    '''
    result = {}
    for key, val in row.items():
        if key in ('image', 'split'):
            result[key] = val
        elif key in ('sw', 'sh', 'iw', 'ih'):
            result[key] = int(val)
        else:
            result[key] = [int(s) if re.match(r'^\d+$', s) else s
                           for s in val.split(';') if s]
    return result
class ScaleSegmentation:
    '''
    Callable transform that resizes a stacked segmentation array of shape
    (layers, height, width) to (layers, target_height, target_width).
    It uses order-0 (nearest-neighbor) zooming, so label codes are copied,
    never interpolated.
    '''
    def __init__(self, target_height, target_width):
        self.target_height, self.target_width = target_height, target_width

    def __call__(self, seg):
        # Axis 0 (the layer axis) keeps its size; only axes 1 and 2 scale.
        height_factor = self.target_height / float(seg.shape[1])
        width_factor = self.target_width / float(seg.shape[2])
        zoom_factors = (1, height_factor, width_factor)
        return ndimage.zoom(seg, zoom_factors, order=0)
def scatter_batch(seg, num_labels, omit_zero=True, dtype=torch.uint8):
    '''
    Utility for scattering segmentations into a one-hot representation.

    seg: integer tensor of shape (batch, depth, *spatial) holding label
        codes in [0, num_labels). Any integer dtype is accepted; it is
        converted to int64 because Tensor.scatter_ requires int64 indices.
    num_labels: size of the one-hot (label) axis of the result.
    omit_zero: if true, clear the channel for label 0.
    dtype: dtype of the returned tensor.

    Returns a tensor of shape (batch, num_labels, *spatial) on seg's device
    that is 1 at [b, k, ...] when any depth layer of seg[b] holds label k
    at that position, and 0 elsewhere.
    '''
    result = torch.zeros((seg.shape[0], num_labels) + tuple(seg.shape[2:]),
            dtype=dtype, device=seg.device)
    result.scatter_(1, seg.long(), 1)
    if omit_zero:
        result[:, 0] = 0
    return result
def ensure_broden_downloaded(directory, resolution, broden_version=1):
    '''
    Make sure ``directory/broden<broden_version>_<resolution>/index.csv``
    exists, downloading and unzipping the matching archive if it does not.

    The archive is cached under ``directory/download``. It is first
    streamed to a ``.part`` file and only renamed to its final name once
    complete, so an interrupted download is not mistaken for a finished
    one on the next call. Resolution must be 224, 227 or 384.
    '''
    assert resolution in [224, 227, 384]
    # NOTE(review): plain-http download; switch to https if the host serves it.
    baseurl = 'http://netdissect.csail.mit.edu/data/'
    dirname = 'broden%d_%d' % (broden_version, resolution)
    if os.path.isfile(os.path.join(directory, dirname, 'index.csv')):
        return # Already downloaded
    # The archive name must follow broden_version, or the final check fails.
    zipfilename = '%s.zip' % dirname
    download_dir = os.path.join(directory, 'download')
    os.makedirs(download_dir, exist_ok=True)
    full_zipfilename = os.path.join(download_dir, zipfilename)
    if not os.path.exists(full_zipfilename):
        url = baseurl + zipfilename
        print('Downloading %s' % url)
        partial_filename = full_zipfilename + '.part'
        # Stream to disk instead of holding the whole archive in memory.
        with urlopen(url) as data, open(partial_filename, 'wb') as f:
            shutil.copyfileobj(data, f)
        os.replace(partial_filename, full_zipfilename)
    print('Unzipping %s' % zipfilename)
    with zipfile.ZipFile(full_zipfilename, 'r') as zip_ref:
        zip_ref.extractall(directory)
    assert os.path.isfile(os.path.join(directory, dirname, 'index.csv'))
def test_broden_dataset():
    '''
    Testing code: smoke-test the dataset on a local 384px Broden copy.

    Prints the name and primary category of the first few labels, then
    iterates over every batch and prints tensor shapes.
    '''
    bds = BrodenDataset('dataset/broden', resolution=384,
            transform=transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor()]),
            transform_segment=transforms.Compose([
                ScaleSegmentation(224, 224)
            ]),
            include_bincount=True)
    # Do not request more worker processes than there are CPUs.
    loader = torch.utils.data.DataLoader(bds, batch_size=100,
            num_workers=min(24, os.cpu_count() or 1))
    # The dataset exposes labels / categories / label_category; bound the
    # loop so a small label table does not raise IndexError.
    for i in range(1, min(20, bds.num_labels)):
        print(bds.labels[i], bds.categories[bds.label_category[i]])
    for i, (im, seg, bc) in enumerate(loader):
        print(i, im.shape, seg.shape, seg.max(), bc.shape)
if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script,
    # never on import.
    test_broden_dataset()