Spaces:
Runtime error
Runtime error
File size: 2,571 Bytes
fafff42 4ac4e3b fafff42 4ac4e3b fafff42 4ac4e3b fafff42 ea60882 4ac4e3b fafff42 4ac4e3b fafff42 4ac4e3b ea60882 4ac4e3b fafff42 ea60882 4ac4e3b fafff42 4ac4e3b ea60882 fafff42 4ac4e3b fafff42 4ac4e3b fafff42 4ac4e3b fafff42 4ac4e3b 0c50f29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import gradio as gr
import numpy as np
import glob
import warnings
import pandas as pd
import matplotlib.pyplot as plt
from utils import OrthogonalRegularizer
from huggingface_hub.keras_mixin import from_pretrained_keras
# Load the pretrained PointNet segmentation model from the Hugging Face Hub.
# The project's OrthogonalRegularizer is supplied through ``custom_objects``
# so the saved model's custom regularizer can be resolved while loading.
_custom_objects = {"OrthogonalRegularizer": OrthogonalRegularizer}
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation",
    custom_objects=_custom_objects,
)
# Example inputs for the Gradio gallery: every CSV under asset/source, each
# wrapped in a one-element list (the Interface has a single input).
# glob returns matches in arbitrary, filesystem-dependent order, so sort them
# to get a stable, reproducible example list.
samples = []  # NOTE(review): unused in this file; kept in case something imports it
input_images = sorted(glob.glob("asset/source/*.csv"))
examples = [[im] for im in input_images]
# Part names the segmentation output is mapped to (inference appends "none"
# as an extra class), and the plot colour for each label, matched by position.
LABELS = ["wing", "body", "tail", "engine"]
COLORS = ["blue", "green", "red", "pink"]
def visualize_data(point_cloud, labels, output_path=None):
    """Scatter-plot a labelled point cloud in 3D, one colour per part label.

    Args:
        point_cloud: array indexable as ``point_cloud[:, 0]``,
            ``point_cloud[:, 1]`` and ``point_cloud[:, 2]`` giving the x, y
            and z coordinate of each point (presumably shape
            ``(n_points, 3)`` -- TODO confirm; only the first three columns
            are read).
        labels: sequence with one label per point. Only points whose label
            appears in ``LABELS`` are drawn (e.g. ``"none"`` is skipped).
        output_path: if given, the figure is saved to this path (parent
            directories are created as needed) and then closed. If omitted,
            the figure is left open as the current pyplot figure.
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(projection="3d")
    # zip pairs each label with its colour and stops at the shorter list,
    # so a label without a colour is skipped rather than indexed past the end.
    for label, color in zip(LABELS, COLORS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=color)
        except IndexError:
            # Best-effort: kept from the original code, which skipped a label
            # when plotting it raised IndexError (presumably an empty subset
            # on some matplotlib versions -- NOTE(review): confirm).
            pass
    ax.legend()
    if output_path:
        out_dir = os.path.dirname(output_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running server accumulates one large figure per request.
        plt.close(fig)
def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment an uploaded point-cloud CSV and render the result to a PNG.

    Args:
        csv_file: the uploaded file. This is either an object with a
            ``.name`` attribute holding its path (older Gradio file objects)
            or a plain path string. The CSV must contain ``x``, ``y`` and
            ``z`` columns.
        output_path: directory in which the rendered PNG is written.
        cpu: unused; kept so existing callers keep working.

    Returns:
        Path of the written PNG. Returns ``None`` (after emitting a warning)
        when no file was given or the path does not exist.
    """
    if csv_file is None:
        warnings.warn("no input file provided")
        return None
    # Accept both file-like upload objects and plain path strings.
    csv_path = getattr(csv_file, "name", csv_file)
    if not os.path.exists(csv_path):
        # The original message referenced an undefined ``im_path``, which
        # turned this warning into a NameError.
        warnings.warn(f"{csv_path} not found")
        return None
    # basename/splitext keep dotted names intact ("a.b.csv" -> "a.b").
    im_name = os.path.splitext(os.path.basename(csv_path))[0]
    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].values
    # TODO: df.iloc[:, 3:] (the columns after x, y, z) was read as y_test for
    # a ground-truth plot; show a ground truth image when it is present.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    # One extra class beyond LABELS is mapped to "none".
    label_map = LABELS + ["none"]
    # Assumes preds is (n_points, n_classes) -- argmax over the class axis.
    pred_labels = [label_map[i] for i in np.argmax(preds, axis=-1)]
    image_path = os.path.join(output_path, f"{im_name}.png")
    visualize_data(inputs, pred_labels, image_path)
    return image_path
article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>Keras example by Soumik Rakshit, Sayak Paul</a></div>"
# Build and launch the Gradio UI.
# - gr.Image replaces the legacy gr.outputs.Image, which was removed in Gradio 4.
# - Queuing is enabled with .queue() instead of launch(enable_queue=True),
#   whose enable_queue argument was removed in Gradio 4.
# - iface stays bound to the Interface itself, not to launch()'s return value.
iface = gr.Interface(
    inference,  # main function
    inputs=[
        "file",
    ],
    outputs=[
        gr.Image(label="result"),  # generated image
    ],
    title="Point cloud segmentation with PointNet",
    article=article,
    # Caching runs inference on every example at startup. When no example
    # CSVs were found, don't ask Gradio to cache an empty example set.
    examples=examples or None,
    cache_examples=bool(examples),
)
iface.queue().launch()
|