# Hugging Face Spaces app (page banner at capture time: "Running on Zero").
import os
import re
import subprocess
import sys

# NOTE(review): the relative order of the imports below is kept exactly as it
# was; open3d_zerogpu_fix is presumably a side-effecting patch that has to be
# imported before the modules that follow it -- confirm before regrouping.
import gradio as gr
import open3d_zerogpu_fix
import spaces
from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed
from inference.utils import get_legend
from inference.inference import segment_obj, get_heatmap
from huggingface_hub import login
# Install the pointops package that lives inside the Pointcept checkout.
# Passing ``cwd=`` runs the build in that directory without changing this
# process's own working directory (the relative "examples/..." paths used
# further down depend on it), and ``sys.executable`` installs into the
# interpreter that is actually running this app rather than whatever
# ``python`` happens to be first on PATH.
_pointops_build = subprocess.run(
    [sys.executable, "setup.py", "install"],
    cwd=os.path.join("Pointcept", "libs", "pointops"),
    check=False,
)
if _pointops_build.returncode != 0:
    # Stay best-effort like the original os.system call (start-up continues),
    # but report the failure instead of discarding the exit status.
    print(
        f"WARNING: pointops install exited with status {_pointops_build.returncode}",
        file=sys.stderr,
    )

# Authenticate with the Hugging Face Hub using the token stored in the
# ``hfkey`` environment variable. Without a token, ``login()`` falls back to
# an interactive prompt, which nobody can answer in a hosted app, so skip the
# call and say so instead.
_hf_token = os.getenv('hfkey')
if _hf_token:
    login(token=_hf_token)
else:
    print(
        "WARNING: environment variable 'hfkey' is not set; skipping Hugging Face Hub login",
        file=sys.stderr,
    )
# Suggested part queries for every bundled example, keyed by example name.
# Each value is one comma-separated string; on_select returns it as the new
# content of the "Part Queries" textbox when a gallery image is clicked.
parts_dict = dict(
    fireplug="bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug",
    mickey="ear,head,arms,hands,body,legs",
    motorvehicle="wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle",
    teddy="head,body,arms,legs",
    lamppost="lighting of a lamppost,pole of a lamppost",
    shirt="sleeve of a shirt,collar of a shirt,body of a shirt",
    capybara="hat worn by a capybara,head,body,feet",
    corgi="head,leg,body,ear",
    pushcar="wheel,body,handle",
    plant="pot,plant",
    chair="back of chair,leg,seat",
)
# Which example collection each bundled object belongs to. on_select uses the
# value as the sub-directory name in "examples/<source>/<name>.pcd".
source_dict = {
    **dict.fromkeys(
        ("fireplug", "mickey", "motorvehicle", "teddy", "lamppost", "shirt"),
        "objaverse",
    ),
    **dict.fromkeys(
        ("capybara", "corgi", "pushcar", "plant", "chair"),
        "wild",
    ),
}
def run_predict(*args):
    """Relay the results of ``predict`` to the caller, one at a time.

    This is the generator registered as the click handler of the "Run"
    button.

    Args:
        *args: Forwarded positionally, unchanged, to ``predict``.

    Yields:
        Every item produced by ``predict(*args)``, in the same order.
    """
    for result in predict(*args):
        yield result
# Characters accepted as separators between part names in the query text.
_QUERY_DELIMITERS = re.compile(r'[,;.|]')


def _split_part_queries(part_queries):
    """Split raw query text on the delimiters and drop blank entries."""
    if not part_queries:
        return []
    pieces = (piece.strip() for piece in _QUERY_DELIMITERS.split(part_queries))
    return [piece for piece in pieces if piece]


def predict(pcd_path, inference_mode, part_queries):
    """Run segmentation or localization on a point cloud file and yield a plot.

    Args:
        pcd_path: Path of the point cloud file to load with ``read_pcd``.
        inference_mode: ``"Segmentation"`` or ``"Localization"``; any other
            value yields ``None``.
        part_queries: Query text. Part names are separated by any of
            ``, ; . |``; surrounding whitespace and empty entries are ignored.

    Yields:
        The value returned by ``render_point_cloud`` for the coloured result,
        or ``None`` for an unrecognised mode.

    Raises:
        gr.Error: If no file is given, if segmentation gets fewer than two
            parts, or if localization does not get exactly one part.
    """
    set_seed()
    if not pcd_path:
        # The file component hands over no path once the user clears it.
        raise gr.Error("Please upload a point cloud file first", duration=5)
    xyz, rgb, normal = read_pcd(pcd_path)

    # One parser for both modes, so they agree on what counts as a delimiter
    # and stray separators ("a,b," or "a,,b") never turn into empty queries.
    parts = _split_part_queries(part_queries)

    if inference_mode == "Segmentation":
        if len(parts) < 2:
            raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
        seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
        legend = get_legend(parts)
        yield render_point_cloud(xyz, seg_rgb, legend=legend)
    elif inference_mode == "Localization":
        if len(parts) > 1:
            raise gr.Error("For localization mode, please provide only one part", duration=5)
        if not parts:
            raise gr.Error("For localization mode, please provide a part query", duration=5)
        heatmap_rgb = get_heatmap(xyz, rgb, normal, parts[0]).cpu().numpy()
        yield render_point_cloud(xyz, heatmap_rgb)
    else:
        yield None
def on_select(evt: gr.SelectData):
    """Translate a gallery click into the matching point cloud and queries.

    Args:
        evt: Selection event; ``evt.value['image']['orig_name']`` is read as
            the file name of the clicked example image.

    Returns:
        A two-item list: the path ``examples/<source>/<name>.pcd`` of the
        example's point cloud, and its suggested part-query string.

    Raises:
        KeyError: If the image name (without extension) is not a key of
            ``source_dict`` / ``parts_dict``.
    """
    # splitext copes with extensions of any length (".jpg", ".png", ".jpeg"),
    # unlike slicing off a fixed four characters.
    obj_name, _extension = os.path.splitext(evt.value['image']['orig_name'])
    src = source_dict[obj_name]
    return [f"examples/{src}/{obj_name}.pcd", parts_dict[obj_name]]
# Build the demo page: intro text, a three-column work area (inputs, input
# preview, result), two example galleries, and the event wiring between them.
# NOTE(review): the nesting below was reconstructed from the component
# structure (three scale-4 columns, then two scale-6 columns) -- confirm
# against the deployed layout.
with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as demo:
    # Headings are centred through a style attribute; "text-align" is not an
    # HTML attribute of its own, so the previous markup had no effect.
    gr.HTML(
        '''<h1 style="text-align: center;">Find Any Part in 3D</h1>
<p style='font-size: 16px;'>This is a demo for Find3D: Find Any Part in 3D! Two modes are supported: <b>segmentation</b> and <b>localization</b>.
<br>
For <b>segmentation mode</b>, please provide multiple part queries in the "queries" text box, in the format of comma-separated string, such as "part1,part2,part3".
After hitting "Run", the model will segment the object into the provided parts.
<br>
For <b>localization mode</b>, please only provide <b>one query string</b> in the "queries" text box. After hitting "Run", the model will generate a heatmap for the provided query text.
Please click on the buttons below "Objaverse" and "In the Wild" for some examples. You can also upload your own .pcd files.</p>
<p style='font-size: 16px;'>Hint:
When uploading your own point cloud, please first close the existing point cloud by clicking on the "x" button.
<br>
We show some sample queries for the provided examples. When working with your own point cloud, feel free to rephrase the query (e.g. "part" vs "part of a object") to achieve better performance!</p>
'''
    )

    # Work area: controls | input preview | result.
    with gr.Row(variant="panel"):
        with gr.Column(scale=4):
            file_upload = gr.File(
                label="Upload Point Cloud File",
                type="filepath",
                file_types=[".pcd"],
                value="examples/objaverse/lamppost.pcd"
            )
            inference_mode = gr.Radio(
                choices=["Segmentation", "Localization"],
                label="Inference Mode",
                value="Segmentation",
            )
            part_queries = gr.Textbox(
                label="Part Queries",
                value="lighting of a lamppost,pole of a lamppost",
            )
            run_button = gr.Button(
                value="Run",
                variant="primary",
            )
        with gr.Column(scale=4):
            # Created hidden and not referenced by any event below.
            input_image = gr.Image(label="Input Image", visible=False, type='pil', image_mode='RGBA', height=290)
            input_point_cloud = gr.Plot(label="Input Point Cloud")
        with gr.Column(scale=4):
            output_point_cloud = gr.Plot(label="Output Result")

    # Example galleries: clicking an image loads its point cloud and queries.
    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            title = gr.HTML('''<h1 style="text-align: center;">Objaverse</h1>
<p style='font-size: 16px;'>Online 3D assets from Objaverse!</p>
''')
            gallery_objaverse = gr.Gallery(
                [
                    ("examples/objaverse/lamppost.jpg", "lamppost"),
                    ("examples/objaverse/fireplug.jpg", "fireplug"),
                    ("examples/objaverse/mickey.jpg", "Mickey"),
                    ("examples/objaverse/motorvehicle.jpg", "motor vehicle"),
                    ("examples/objaverse/teddy.jpg", "teddy bear"),
                    ("examples/objaverse/shirt.jpg", "shirt"),
                ],
                columns=3,
                allow_preview=False,
            )
            gallery_objaverse.select(
                fn=on_select,
                inputs=None,
                outputs=[file_upload, part_queries],
            )
        with gr.Column(scale=6):
            title = gr.HTML("""<h1 style="text-align: center;">In the Wild</h1>
<p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
""")
            gallery_wild = gr.Gallery(
                [
                    ("examples/wild/capybara.png", "DALLE-capybara"),
                    ("examples/wild/corgi.jpg", "DALLE-corgi"),
                    ("examples/wild/plant.jpg", "iPhone-plant"),
                    ("examples/wild/pushcar.jpg", "iPhone-pushcar"),
                    ("examples/wild/chair.jpg", "iPhone-chair"),
                ],
                columns=3,
                allow_preview=False,
            )
            gallery_wild.select(
                fn=on_select,
                inputs=None,
                outputs=[file_upload, part_queries],
            )

    # Re-render the input preview whenever the selected file changes.
    file_upload.change(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )
    # "Run" streams the results of predict into the output plot.
    run_button.click(
        fn=run_predict,
        inputs=[file_upload, inference_mode, part_queries],
        outputs=[output_point_cloud],
    )
    # Render the default example once when the page first loads.
    demo.load(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )

demo.launch()