File size: 2,571 Bytes
fafff42
 
 
 
 
 
4ac4e3b
fafff42
4ac4e3b
fafff42
 
 
4ac4e3b
 
 
fafff42
 
 
ea60882
 
4ac4e3b
 
 
 
fafff42
 
 
 
 
 
 
 
 
 
 
 
 
 
4ac4e3b
fafff42
 
 
 
 
 
 
4ac4e3b
 
ea60882
4ac4e3b
 
 
fafff42
ea60882
4ac4e3b
 
fafff42
 
4ac4e3b
ea60882
fafff42
4ac4e3b
fafff42
 
 
 
4ac4e3b
 
 
fafff42
 
 
 
4ac4e3b
 
 
 
 
 
fafff42
4ac4e3b
 
0c50f29
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import gradio as gr
import numpy as np
import glob
import warnings
import pandas as pd
import matplotlib.pyplot as plt

from utils import OrthogonalRegularizer
from huggingface_hub.keras_mixin import from_pretrained_keras

# load model
# OrthogonalRegularizer is passed as a custom object so Keras can rebuild the
# saved model, which references it by name.
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation", custom_objects={"OrthogonalRegularizer": OrthogonalRegularizer}
)

# Examples
samples = []  # NOTE(review): not used in this file; kept in case other code imports it
# Example point-cloud CSVs offered in the UI. glob() yields files in arbitrary
# filesystem order, so sort them for a stable, reproducible example list.
input_images = sorted(glob.glob("asset/source/*.csv"))
examples = [[im] for im in input_images]
# LABELS[i] is the part name for model output index i (see `inference`), and
# points with that label are drawn in COLORS[i] (see `visualize_data`).
LABELS = ["wing", "body", "tail", "engine"]
COLORS = ["blue", "green", "red", "pink"]


def visualize_data(point_cloud, labels, output_path=None):
    """Draw a labelled point cloud as a 3D scatter plot, optionally saving it.

    Args:
        point_cloud: Array-like indexable as ``point_cloud[:, 0]``, ``[:, 1]``,
            ``[:, 2]`` for the x, y, z coordinates of each point.
        labels: One label per point. Points whose label is in ``LABELS`` are
            drawn in the matching ``COLORS`` entry; any other label is not drawn.
        output_path: If truthy, the figure is saved to this path (its parent
            directory is created if needed) and then closed.
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(projection="3d")
    # zip stops at the shorter list, so a label without a colour is skipped --
    # the same outcome as the former IndexError guard, without exception flow.
    for label, color in zip(LABELS, COLORS):
        c_df = df[df["label"] == label]
        ax.scatter(c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=color)
    ax.legend()
    if output_path:
        try:
            out_dir = os.path.dirname(output_path)
            # A bare file name has an empty dirname, and os.makedirs("") raises.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(output_path)
        finally:
            # The Gradio app calls this once per request; close the saved figure
            # so pyplot does not keep an ever-growing set of open figures.
            plt.close(fig)


def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment an uploaded point cloud and render the predicted parts to a PNG.

    Args:
        csv_file: The uploaded file as passed by Gradio's ``"file"`` input (an
            object exposing ``.name``), or a plain path string. The CSV must
            have ``x``, ``y`` and ``z`` columns.
        output_path: Directory the rendered image is written into.
        cpu: Unused; kept so existing callers passing it keep working.

    Returns:
        Path of the saved PNG, or ``None`` (after emitting a warning) when no
        file was given or the file does not exist.
    """
    if csv_file is None:
        warnings.warn("No CSV file was provided")
        return None

    csv_path = getattr(csv_file, "name", csv_file)
    # basename/splitext are separator-agnostic and keep dots inside the stem
    # (e.g. "plane.v2.csv" -> "plane.v2").
    im_name = os.path.splitext(os.path.basename(csv_path))[0]

    if not os.path.exists(csv_path):
        # Previously this message referenced an undefined name, so a missing
        # file raised NameError instead of warning.
        warnings.warn(f"{csv_path} not found for {im_name}")
        return None

    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].values
    # TODO: show a ground-truth image when the CSV has label columns after
    # x, y, z (df.iloc[:, 3:]) -- presumably one-hot part labels; confirm.

    # Predict on a batch of one cloud and take its per-point class scores.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    # Output index i maps to LABELS[i]; the extra trailing class is "none".
    label_map = LABELS + ["none"]
    point_labels = [label_map[idx] for idx in np.argmax(preds, axis=-1)]

    image_path = os.path.join(output_path, f"{im_name}.png")
    visualize_data(inputs, point_labels, image_path)
    return image_path


# HTML footer rendered under the demo: author credit and a link to the
# original Keras example.
article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>Keras example by Soumik Rakshit, Sayak Paul</a></div>"

# NOTE(review): gr.outputs.Image and launch(enable_queue=...) are older Gradio
# APIs that newer releases deprecate/remove (gr.Image, Interface.queue());
# confirm against the pinned Gradio version before upgrading.
iface = gr.Interface(
    inference,  # main function
    inputs=[
        "file",
    ],
    outputs=[
        gr.outputs.Image(label="result"),  # generated image
    ],
    title="Point cloud segmentation with PointNet",
    article=article,
    examples=examples,
    # Ask Gradio to precompute the outputs for the example CSVs.
    cache_examples=True,
)
# Launch separately so `iface` stays bound to the Interface itself; the value
# returned by launch() is not the Interface.
iface.launch(enable_queue=True)