Spaces:
Build error
Build error
import gradio as gr | |
import mathutils | |
import math | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import matplotlib.cm as cmx | |
import os.path as osp | |
import h5py | |
import random | |
import torch | |
import torch.nn as nn | |
from GDANet_cls import GDANET | |
from DGCNN import DGCNN | |
def _load_classifier(model, checkpoint_path):
    """Wrap ``model`` in ``nn.DataParallel``, load CPU weights, and set eval mode.

    Args:
        model: An un-wrapped ``nn.Module`` instance.
        checkpoint_path: Path to a state dict saved with ``torch.save``.

    Returns:
        The ``nn.DataParallel``-wrapped model, in eval mode.

    The wrapper is applied *before* loading, presumably because the checkpoints
    were saved from DataParallel-wrapped models (state-dict keys prefixed with
    ``module.``) -- confirm against the training code.
    """
    model = nn.DataParallel(model)
    # map_location forces all tensors onto the CPU so no GPU is required.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Class names, one per line; CLASS_NAME[i] is the name for model output index i
# (see recognize_pcd) and for dataset label value i.
with open('shape_names.txt', encoding='utf-8') as f:
    CLASS_NAME = f.read().splitlines()

# GDANet checkpoint (shown in the UI as "Ours (GDANet+WolfMix)") and the DGCNN baseline.
model_gda = _load_classifier(GDANET(), './GDANet_WOLFMix.t7')
model_dgcnn = _load_classifier(DGCNN(), './dgcnn.t7')
def pyplot_draw_point_cloud(points, corruption):
    """Render a point cloud as a 3-D scatter plot and save it to ``visualization.png``.

    The cloud is rotated into a fixed viewing orientation and coloured by its
    rotated y coordinate with the 'winter' colormap on a fixed [-1, 1] scale.
    It is drawn inside a [-1, 1]^3 box with axes hidden and written at 200 dpi.

    Args:
        points: Array indexed as ``points[:, 0..2]`` (x, y, z per row);
            presumably shape (N, 3) with coordinates roughly in [-1, 1],
            given the fixed axis limits -- confirm.
        corruption: Text used as the plot title. Despite the name, the caller
            in this file passes the sample's class name.
    """
    # Fixed viewing rotation: -90 deg about X, then 180 deg about Z.
    # These are the standard right-handed single-axis rotation matrices (what
    # mathutils.Euler(...).to_matrix() returns for a single-axis Euler). They are
    # built with numpy, so drawing does not require Blender's mathutils package.
    # Points are row vectors, so `points @ R` applies R's transpose. The net
    # effect maps (x, y, z) to approximately (-x, z, y).
    x_angle = -math.pi / 2
    rot_x = np.array([
        [1.0, 0.0, 0.0],
        [0.0, math.cos(x_angle), -math.sin(x_angle)],
        [0.0, math.sin(x_angle), math.cos(x_angle)],
    ])
    z_angle = math.pi
    rot_z = np.array([
        [math.cos(z_angle), -math.sin(z_angle), 0.0],
        [math.sin(z_angle), math.cos(z_angle), 0.0],
        [0.0, 0.0, 1.0],
    ])
    points = np.dot(np.dot(points, rot_x), rot_z)
    x, y, z = points[:, 0], points[:, 1], points[:, 2]

    # Colour each point by its rotated y coordinate, normalised to [-1, 1].
    color_norm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap=plt.get_cmap('winter'))

    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, z, c=scalar_map.to_rgba(y))
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.set_axis_off()
        ax.set_title(corruption, fontsize=30)
        fig.tight_layout()
        fig.savefig('visualization.png', bbox_inches='tight', dpi=200)
    finally:
        # Close this figure explicitly (not pyplot's "current" figure), even if
        # saving fails, so figures do not accumulate across requests.
        plt.close(fig)
def load_dataset(corruption_idx, severity):
    """Load one ModelNet-C split from an HDF5 file under ``modelnet_c/``.

    Args:
        corruption_idx: Index into the corruption list below. Its order must
            match the "Corruption Type" radio options, which pass their
            position (``type="index"``).
        severity: 1-based corruption severity (the UI slider spans 1-5).
            Ignored for the clean split. File names use the 0-based value,
            e.g. severity 5 -> ``jitter_4.h5``.

    Returns:
        ``(data, label)``: the file's ``data`` dataset as float32 and its
        ``label`` dataset as int64 numpy arrays.
    """
    corruptions = [
        'clean',
        'scale',
        'jitter',
        'rotate',
        'dropout_global',
        'dropout_local',
        'add_global',
        'add_local',
    ]
    corruption_type = corruptions[corruption_idx]
    if corruption_type == 'clean':
        filename = corruption_type + '.h5'
    else:
        # The slider may deliver a float (e.g. 5.0); cast so the name ends in
        # "_4.h5", not "_4.0.h5".
        filename = '{}_{}.h5'.format(corruption_type, int(severity) - 1)
    # Open read-only explicitly: older h5py versions default to mode 'a'
    # (read/write, create if missing). The context manager closes the file
    # even if a read fails.
    with h5py.File(osp.join('modelnet_c', filename), 'r') as f:
        data = f['data'][:].astype('float32')
        label = f['label'][:].astype('int64')
    return data, label
def recognize_pcd(model, pcd, top_k=5, class_names=None):
    """Classify one point cloud and return its most probable classes.

    Args:
        model: Classifier called as ``model(batch)``. ``batch`` is ``pcd``
            with a leading batch axis and its last two axes swapped, i.e.
            (N, C) -> (1, C, N). The output is softmaxed over its last axis
            and flattened, so presumably shape (1, num_classes) -- confirm.
        pcd: A single cloud as a 2-D array, presumably (N, 3) -- confirm.
        top_k: Maximum number of classes to return. It is clamped to the
            number of classes so small class sets do not make ``topk`` raise.
        class_names: Sequence mapping output index -> name; defaults to the
            module-level ``CLASS_NAME``.

    Returns:
        Dict ``{class_name: probability}`` for the top classes, in descending
        probability order (the form ``gr.components.Label`` consumes).
    """
    if class_names is None:
        class_names = CLASS_NAME
    # Add a batch dimension and put the point axis last: (N, C) -> (1, C, N).
    batch = torch.tensor(pcd).unsqueeze(0).permute(0, 2, 1)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)
    probs = logits.softmax(-1).flatten()
    k = min(top_k, probs.numel())
    top_probs, top_idx = torch.topk(probs, k)
    return {
        class_names[i]: float(p)
        for i, p in zip(top_idx.tolist(), top_probs.tolist())
    }
def run(seed, corruption_idx, severity):
    """Gradio callback: visualise one ModelNet-C sample and classify it with both models.

    Args:
        seed: Value of the "Sample Seed" box, used directly as a sample index.
            It is wrapped modulo the dataset size, so any integer selects a
            sample instead of raising IndexError. An empty box (None) selects
            sample 0.
        corruption_idx: Position of the selected corruption type.
        severity: Corruption severity from the slider.

    Returns:
        ``(image_path, dgcnn_prediction, gdanet_prediction)`` where each
        prediction is the dict returned by ``recognize_pcd``.
    """
    data, label = load_dataset(corruption_idx, severity)
    # Python's % keeps in-range negative indices pointing at the same sample
    # (e.g. -1 -> last) while wrapping anything out of range.
    sample_idx = (0 if seed is None else int(seed)) % len(data)
    pcd, cls = data[sample_idx], label[sample_idx]
    # The plot title is the ground-truth class name; each label row is
    # indexed with [0] (presumably shape (1,) per sample -- confirm).
    pyplot_draw_point_cloud(pcd, CLASS_NAME[cls[0]])
    # pyplot_draw_point_cloud writes this fixed file name.
    output = 'visualization.png'
    return output, recognize_pcd(model_dgcnn, pcd), recognize_pcd(model_gda, pcd)
if __name__ == '__main__':
    # The three inputs map positionally onto run(seed, corruption_idx, severity);
    # the three outputs match run()'s return tuple.
    iface = gr.Interface(
        fn=run,
        inputs=[
            gr.components.Number(label='Sample Seed', precision=0),
            # type="index" passes the option's position, so this order must match
            # the corruption list in load_dataset().
            gr.components.Radio(
                ['Clean', 'Scale', 'Jitter', 'Rotate', 'Drop Global', 'Drop Local', 'Add Global', 'Add Local'],
                value='Clean', type="index", label='Corruption Type'),
            gr.components.Slider(1, 5, step=1, label='Corruption severity'),
        ],
        outputs=[
            # run() returns a path to the saved PNG. "file" was the Gradio 2.x
            # name; Gradio 3.x's Image component (gr.components.*) expects "filepath".
            gr.components.Image(type="filepath", label="Visualization"),
            gr.components.Label(num_top_classes=5, label="Baseline (DGCNN) Prediction"),
            gr.components.Label(num_top_classes=5, label="Ours (GDANet+WolfMix) Prediction"),
        ],
        live=False,
        allow_flagging='never',
        title="Benchmarking and Analyzing Point Cloud Classification under Corruptions [ICML 2022]",
        description="Welcome to the demo of ModelNet-C! You can visualize various types of corrupted point clouds in ModelNet-C and see how our proposed techniques contribute to robust predictions compared to baseline methods.",
        examples=[
            [0, 'Jitter', 5],
            [999, 'Drop Local', 5],
        ],
        # css=".output-image, .image-preview {height: 500px !important}",
        article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/ModelNet-C' target='_blank'>ModelNet-C @ GitHub</a></p> ",
    )
    iface.launch()