from __future__ import annotations import math import os import subprocess from pathlib import Path import gradio as gr import pygltflib import trimesh def convert_formats(path_input, target_ext): """ Converts an input 3D model under input path to the target extensions format and returns a path to that file. :param path_input: path to user input :param target_ext: target extension :return: path to the input 3D model stored in target format. """ path_input_base, ext = os.path.splitext(path_input) if ext == "." + target_ext: return path_input path_output = path_input_base + "." + target_ext if not os.path.exists(path_output): trimesh.load_mesh(path_input).export(path_output) return path_output def add_lights(path_input, path_output): glb = pygltflib.GLTF2().load(path_input) N = 3 # default max num lights in Babylon.js is 4 angle_step = 2 * math.pi / N lights_extension = { "lights": [ { "type": "directional", "color": [1.0, 1.0, 1.0], "intensity": 2.0 } for _ in range(N) ] } if "KHR_lights_punctual" not in glb.extensionsUsed: glb.extensionsUsed.append("KHR_lights_punctual") glb.extensions["KHR_lights_punctual"] = lights_extension light_nodes = [] for i in range(N): angle = i * angle_step rotation = [ 0.0, math.sin(angle / 2), 0.0, math.cos(angle / 2) ] node = { "rotation": rotation, "extensions": { "KHR_lights_punctual": { "light": i } } } light_nodes.append(node) light_node_indices = list(range(len(glb.nodes), len(glb.nodes) + N)) glb.nodes.extend(light_nodes) root_node_index = glb.scenes[glb.scene].nodes[0] root_node = glb.nodes[root_node_index] if hasattr(root_node, 'children'): root_node.children.extend(light_node_indices) else: root_node.children = light_node_indices glb.save(path_output) class Model3D(gr.Model3D): """ A simple overload of Gradio Model3D that accepts arbitrary 3D formats supported by trimesh. """ def postprocess(self, y: str | Path | None) -> dict[str, str] | None: if y is not None: y = convert_formats(y, "glb") out = super().postprocess(y) return out def breathe_new_life_into_3d_model(path_input, prompt): """ @inproceedings{wang2023breathing, title={Breathing New Life into 3D Assets with Generative Repainting}, author={Wang, Tianfu and Kanakis, Menelaos and Schindler, Konrad and Van Gool, Luc and Obukhov, Anton}, booktitle={Proceedings of the British Machine Vision Conference (BMVC)}, year={2023}, publisher={BMVA Press} } """ path_output_dir = path_input + ".output" os.makedirs(path_output_dir, exist_ok=True) path_input_ply = convert_formats(path_input, "ply") cmd = [ "bash", "/repainting_3d_assets/code/scripts/conda_run.sh", "/repainting_3d_assets", path_input_ply, path_output_dir, prompt, ] result = subprocess.run(cmd, env=os.environ, text=True) if result.returncode != 0: print(f"Output: {result.stdout}") print(f"Stderr: {result.stderr}") raise RuntimeError("Processing failed") path_output_glb = os.path.join(path_output_dir, "model_draco.glb") path_output_glb_vis = path_output_glb[:-4] + "_vis.glb" add_lights(path_output_glb, path_output_glb_vis) return path_output_glb_vis def run(): desc = """

badge-github-stars social

Repaint your 3D models with a text prompt, guided by a method from our BMVC'2023 Oral paper 'Breathing New Life into 3D Assets with Generative Repainting'. Simply drop a model into the left pane, specify your repainting preferences, and wait for the outcome (~20 min). Explore precomputed examples at the bottom, or follow the Project Website badge for additional precomputed models and comparison with other repainting techniques.

""" demo = gr.Interface( title="Repainting 3D Assets", description=desc, thumbnail="thumbnail.jpg", fn=breathe_new_life_into_3d_model, inputs=[ Model3D( camera_position=(30.0, 90.0, 3.0), elem_classes="viewport", label="Input Model", ), gr.Textbox(label="Text Prompt"), ], outputs=[ gr.Model3D( camera_position=(30.0, 90.0, 3.0), elem_classes="viewport", label="Repainted Model", ), ], examples=[ [ os.path.join(os.path.dirname(__file__), "files/horse.ply"), "pastel superhero unicorn", ], ], cache_examples=True, css=""" .viewport { aspect-ratio: 16/9; } """, allow_flagging="never", ) demo.queue().launch(server_name="0.0.0.0", server_port=7860) if __name__ == "__main__": run()