Spaces:
Build error
Build error
File size: 2,299 Bytes
97069e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import matplotlib.pyplot as plt
import numpy
def plot_tensor_images(data, **kwargs):
    """Display a batch of images in a near-square grid of subplots.

    Args:
        data: Batch of images laid out as ``(N, C, H, W)`` (the ``permute``
            below moves dim 1 to the end before plotting). Pixel values are
            mapped from ``[-1, 1]`` to ``[0, 255]``; values outside that
            range saturate instead of wrapping around. Assumed to be a torch
            tensor (uses ``.permute`` / ``.clamp`` / ``.byte`` / ``.cpu``)
            -- TODO confirm against callers. ``C`` may be 3 or 4 (RGB/RGBA)
            or 1 (shown in grayscale).
        **kwargs: Forwarded to ``plt.subplots``. When ``figsize`` is not
            given, it is chosen so that each data pixel covers roughly one
            display pixel. ``squeeze`` is always forced to ``False``.

    Raises:
        ValueError: If ``data`` contains no images.
    """
    count = data.shape[0]
    if count == 0:
        raise ValueError('plot_tensor_images needs at least one image')

    # Map [-1, 1] -> [0, 255]. Clamp before the uint8 cast so out-of-range
    # values saturate at 0/255 instead of wrapping around.
    data = ((data + 1) / 2 * 255).clamp(0, 255)
    data = data.permute(0, 2, 3, 1).byte().cpu().numpy()  # (N, H, W, C)
    img_h, img_w = data.shape[1], data.shape[2]

    cmap = None
    if data.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; drop the channel axis and render
        # single-channel images in grayscale.
        data = data[..., 0]
        cmap = 'gray'

    # Near-square grid: `width` columns, just enough rows to hold them all.
    width = int(numpy.ceil(numpy.sqrt(count)))
    height = int(numpy.ceil(count / float(width)))

    kwargs = dict(kwargs)
    margin = 0.01
    if 'figsize' not in kwargs:
        # Size figure to one display pixel per data pixel, with a little
        # extra room for the inter-panel spacing applied below.
        dpi = plt.rcParams['figure.dpi']
        kwargs['figsize'] = (
            (1 + margin) * (width * img_w / dpi),
            (1 + margin) * (height * img_h / dpi))
    # squeeze=False always yields a 2-D array of Axes, even for 1x1 / 1xN.
    kwargs['squeeze'] = False
    _, axarr = plt.subplots(height, width, **kwargs)

    # axarr.flat walks cells row-major, matching the (i // width, i % width)
    # layout. Cells past the last image stay empty but are still hidden.
    for idx, ax in enumerate(axarr.flat):
        if idx < count:
            ax.imshow(data[idx], cmap=cmap)
        ax.axis('off')

    plt.subplots_adjust(wspace=margin, hspace=margin,
                        left=0, right=1, bottom=0, top=1)
    plt.show()
def plot_max_heatmap(data, shape=None, **kwargs):
    """Plot, for each item in a batch, its maximum over dim 1 as a heatmap.

    All panels share one colour scale (the global min/max of the reduced
    data) and use the ``'hot'`` colormap, so they are directly comparable.

    Args:
        data: Batch tensor, presumably laid out as ``(N, C, H, W)``; it is
            reduced with ``data.max(1)[0]`` (values of the max over dim 1).
            Assumed to be a torch tensor (``.max(dim)`` returning
            ``(values, indices)``, ``.cpu``) -- TODO confirm against callers.
        shape: ``(height, width)`` used only to size the figure. Defaults to
            ``data.shape[2:]``.
        **kwargs: Forwarded to ``plt.subplots``. When ``figsize`` is not
            given, it is chosen so that each data pixel covers roughly one
            display pixel. ``squeeze`` is always forced to ``False``.

    Raises:
        ValueError: If ``data`` contains no items.
    """
    if shape is None:
        shape = data.shape[2:]
    data = data.max(1)[0].cpu().numpy()
    count = data.shape[0]
    if count == 0:
        raise ValueError('plot_max_heatmap needs at least one item')

    # One shared colour scale across every panel.
    vmin = data.min()
    vmax = data.max()

    # Near-square grid: `width` columns, just enough rows to hold them all.
    width = int(numpy.ceil(numpy.sqrt(count)))
    height = int(numpy.ceil(count / float(width)))

    kwargs = dict(kwargs)
    margin = 0.01
    if 'figsize' not in kwargs:
        # Size figure to one display pixel per data pixel. The (1 + margin)
        # factor compensates for the wspace/hspace applied below, matching
        # plot_tensor_images.
        dpi = plt.rcParams['figure.dpi']
        kwargs['figsize'] = (
            (1 + margin) * (width * shape[1] / dpi),
            (1 + margin) * (height * shape[0] / dpi))
    # squeeze=False always yields a 2-D array of Axes, even for 1x1 / 1xN.
    kwargs['squeeze'] = False
    _, axarr = plt.subplots(height, width, **kwargs)

    # Row-major walk over every cell; unused trailing cells are just hidden.
    for idx, ax in enumerate(axarr.flat):
        if idx < count:
            ax.imshow(data[idx], vmin=vmin, vmax=vmax, cmap='hot')
        ax.axis('off')

    plt.subplots_adjust(wspace=margin, hspace=margin,
                        left=0, right=1, bottom=0, top=1)
    plt.show()
|