toshas's picture
Initial commit
c184ad2
raw
history blame
6.42 kB
from __future__ import annotations
import math
import os
import subprocess
from pathlib import Path
import gradio as gr
import pygltflib
import trimesh
def convert_formats(path_input, target_ext):
    """
    Converts an input 3D model under input path to the target extension's format and returns a path to that file.

    The extension of the input is compared with the target case-insensitively, and ``target_ext`` may be given
    with or without a leading dot. If a file with the target extension already exists next to the input, it is
    reused instead of converting again.

    :param path_input: path to user input
    :param target_ext: target extension, e.g. "glb" or ".glb"
    :return: path to the input 3D model stored in target format; ``path_input`` itself when it already has the
        target extension.
    """
    # Accept both "glb" and ".glb" so that the output never ends up with a doubled dot.
    target_ext = target_ext.lstrip(".")
    path_input_base, ext = os.path.splitext(path_input)
    # Case-insensitive match: "model.GLB" is already a glb file and needs no conversion.
    if ext.lower() == "." + target_ext.lower():
        return path_input
    path_output = path_input_base + "." + target_ext
    # A previously converted file next to the input is reused as a cache.
    if not os.path.exists(path_output):
        trimesh.load_mesh(path_input).export(path_output)
    return path_output
def add_lights(path_input, path_output):
    """
    Adds three white directional lights to a glTF model and saves the result under a new path.

    The lights are declared through the ``KHR_lights_punctual`` extension (any previous content of that
    extension in the file is replaced) and attached, one node per light, as children of the first root node of
    the scene. Each light node is rotated about the Y axis by a multiple of 120 degrees.

    :param path_input: path to the glTF model to read
    :param path_output: path where the model with lights is written
    """
    glb = pygltflib.GLTF2().load(path_input)

    num_lights = 3  # default max num lights in Babylon.js is 4
    angle_step = 2 * math.pi / num_lights

    lights_extension = {
        "lights": [
            {
                "type": "directional",
                "color": [1.0, 1.0, 1.0],
                "intensity": 2.0
            }
            for _ in range(num_lights)
        ]
    }

    if "KHR_lights_punctual" not in glb.extensionsUsed:
        glb.extensionsUsed.append("KHR_lights_punctual")
    glb.extensions["KHR_lights_punctual"] = lights_extension

    light_nodes = []
    for i in range(num_lights):
        angle = i * angle_step
        # Quaternion (x, y, z, w) of a rotation by `angle` about the Y axis.
        rotation = [
            0.0,
            math.sin(angle / 2),
            0.0,
            math.cos(angle / 2)
        ]
        light_nodes.append({
            "rotation": rotation,
            "extensions": {
                "KHR_lights_punctual": {
                    "light": i
                }
            }
        })

    # The new nodes are appended at the end, so their indices follow the existing ones.
    first_light_index = len(glb.nodes)
    light_node_indices = list(range(first_light_index, first_light_index + num_lights))
    glb.nodes.extend(light_nodes)

    # The default scene index is optional in glTF; fall back to the first scene when it is not set.
    scene_index = glb.scene if glb.scene is not None else 0
    root_node_index = glb.scenes[scene_index].nodes[0]
    root_node = glb.nodes[root_node_index]

    # A missing or None children list is replaced; an existing one is extended in place.
    if getattr(root_node, "children", None) is None:
        root_node.children = light_node_indices
    else:
        root_node.children.extend(light_node_indices)

    glb.save(path_output)
class Model3D(gr.Model3D):
    """
    Gradio Model3D component that converts its value to GLB (via ``convert_formats``) before display, so that
    arbitrary 3D formats supported by trimesh can be shown.
    """

    def postprocess(self, y: str | Path | None) -> dict[str, str] | None:
        # None is passed through untouched; any path is first converted to a GLB file.
        converted = None if y is None else convert_formats(y, "glb")
        return super().postprocess(converted)
def breathe_new_life_into_3d_model(path_input, prompt):
    """
    Repaints a 3D model according to a text prompt and returns a path to a GLB file prepared for viewing.

    The input is converted to PLY and handed to the external repainting script together with an output
    directory (``<path_input>.output``) and the prompt. The resulting ``model_draco.glb`` is then saved with
    additional lights as ``model_draco_vis.glb``.

    @inproceedings{wang2023breathing,
    title={Breathing New Life into 3D Assets with Generative Repainting},
    author={Wang, Tianfu and Kanakis, Menelaos and Schindler, Konrad and Van Gool, Luc and Obukhov, Anton},
    booktitle={Proceedings of the British Machine Vision Conference (BMVC)},
    year={2023},
    publisher={BMVA Press}
    }

    :param path_input: path to the input 3D model
    :param prompt: text prompt describing the desired repainting
    :return: path to the repainted model with lights added
    :raises gradio.Error: if no input model is given
    :raises RuntimeError: if the repainting script exits with a non-zero code
    """
    if path_input is None:
        # Without this guard the string concatenation below fails with an opaque TypeError.
        raise gr.Error("Please provide an input 3D model.")

    path_output_dir = path_input + ".output"
    os.makedirs(path_output_dir, exist_ok=True)
    path_input_ply = convert_formats(path_input, "ply")
    cmd = [
        "bash",
        "/repainting_3d_assets/code/scripts/conda_run.sh",
        "/repainting_3d_assets",
        path_input_ply,
        path_output_dir,
        prompt,
    ]
    # The output of the script is not captured: it goes straight to this process's stdout/stderr,
    # so result.stdout and result.stderr are always None and only the exit code is reported here.
    result = subprocess.run(cmd, env=os.environ, text=True)
    if result.returncode != 0:
        print(f"Repainting script exited with code {result.returncode}")
        raise RuntimeError("Processing failed")
    path_output_glb = os.path.join(path_output_dir, "model_draco.glb")
    # Strip the ".glb" suffix to derive the name of the copy with lights.
    path_output_glb_vis = path_output_glb[:-4] + "_vis.glb"
    add_lights(path_output_glb, path_output_glb_vis)
    return path_output_glb_vis
def run():
    """
    Builds the Gradio interface of the repainting demo and serves it on 0.0.0.0:7860 with queueing enabled.
    """
    # HTML shown under the title: project badges followed by a short usage note.
    description_html = """
<p align="center">
<a title="Website" href="https://www.obukhov.ai/repainting_3d_assets" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2309.08523" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/kongdai123/repainting_3d_assets" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/kongdai123/repainting_3d_assets?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Repaint your 3D models with a text prompt, guided by a method from our BMVC'2023 Oral paper 'Breathing New Life
into 3D Assets with Generative Repainting'. Simply drop a model into the left pane, specify your repainting
preferences, and wait for the outcome (~20 min). Explore precomputed examples at the bottom, or follow the
Project Website badge for additional precomputed models and comparison with other repainting techniques.
</p>
"""
    # Both 3D viewports share the "viewport" class styled here.
    viewport_css = """
.viewport {
aspect-ratio: 16/9;
}
"""
    # The example model is located relative to this file.
    example_model = os.path.join(os.path.dirname(__file__), "files/horse.ply")

    input_components = [
        Model3D(
            camera_position=(30.0, 90.0, 3.0),
            elem_classes="viewport",
            label="Input Model",
        ),
        gr.Textbox(label="Text Prompt"),
    ]
    output_components = [
        gr.Model3D(
            camera_position=(30.0, 90.0, 3.0),
            elem_classes="viewport",
            label="Repainted Model",
        ),
    ]

    interface = gr.Interface(
        fn=breathe_new_life_into_3d_model,
        inputs=input_components,
        outputs=output_components,
        title="Repainting 3D Assets",
        description=description_html,
        thumbnail="thumbnail.jpg",
        examples=[[example_model, "pastel superhero unicorn"]],
        cache_examples=True,
        css=viewport_css,
        allow_flagging="never",
    )
    queued = interface.queue()
    queued.launch(server_name="0.0.0.0", server_port=7860)
# Start the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    run()