|
import base64 |
|
import logging |
|
import os |
|
import random |
|
import sys |
|
|
|
import comfy.model_management |
|
import folder_paths |
|
import numpy as np |
|
import torch |
|
import trimesh |
|
from PIL import Image |
|
from trimesh.exchange import gltf |
|
|
|
# Put this file's directory on the import path so the ``spar3d`` package
# (presumably vendored alongside this node file) can be imported. This must
# run before the ``spar3d`` imports below.
sys.path.append(os.path.dirname(__file__))

from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D
from spar3d.utils import foreground_crop

# Used as the CATEGORY of every node class defined in this file.
SPAR3D_CATEGORY = "SPAR3D"
# Pretrained-model identifier handed to SPAR3D.from_pretrained by SPAR3DLoader.
SPAR3D_MODEL_NAME = "stabilityai/spar3d"
|
|
|
|
|
class SPAR3DLoader:
    """ComfyUI node that loads the pretrained SPAR3D model.

    The model is created with ``SPAR3D.from_pretrained`` from
    ``SPAR3D_MODEL_NAME``. It is moved to the torch device reported by ComfyUI
    and switched to inference (eval) mode.
    """

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("spar3d_model",)
    RETURN_TYPES = ("SPAR3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        # The loader takes no user-facing inputs.
        return {"required": {}}

    def load(self):
        """Return a one-tuple holding the ready-to-use SPAR3D model."""
        target_device = comfy.model_management.get_torch_device()
        spar3d_model = SPAR3D.from_pretrained(
            SPAR3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
        )
        spar3d_model.to(target_device)
        spar3d_model.eval()
        return (spar3d_model,)
|
|
|
|
|
class SPAR3DPreview:
    """ComfyUI output node that previews meshes in the browser.

    Each mesh is wrapped in a trimesh scene, exported as binary glTF (GLB) with
    normals, and sent to the UI as a base64 string under the ``glbs`` key.
    """

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Encode every mesh in ``mesh`` as base64 GLB for the frontend."""
        encoded_models = [
            base64.b64encode(
                gltf.export_glb(trimesh.Scene(single), include_normals=True)
            ).decode("utf-8")
            for single in mesh
        ]
        return {"ui": {"glbs": encoded_models}}
|
|
|
|
|
class SPAR3DSampler:
    """ComfyUI node that reconstructs a mesh and a point cloud from one image.

    The input image is converted to an RGBA PIL image, taking its alpha from
    ``mask`` when one is connected. It is then cropped around the foreground
    with ``foreground_crop`` and passed to ``model.run_image``. Remeshing
    choices are offered only for the backends the ``spar3d`` package reports
    as available.
    """

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh", "pointcloud")
    RETURN_TYPES = ("MESH", "POINTCLOUD")

    @classmethod
    def INPUT_TYPES(s):
        remesh_choices = ["none"]
        if TRIANGLE_REMESH_AVAILABLE:
            remesh_choices.append("triangle")
        if QUAD_REMESH_AVAILABLE:
            remesh_choices.append("quad")

        opt_dict = {
            "mask": ("MASK",),
            "pointcloud": ("POINTCLOUD",),
            "target_type": (["none", "vertex", "face"],),
            "target_count": (
                "INT",
                {"default": 1000, "min": 3, "max": 20000, "step": 1},
            ),
            "guidance_scale": (
                "FLOAT",
                {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
            ),
            "seed": (
                "INT",
                {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
            ),
        }
        # Only expose the remesh selector when at least one backend exists.
        if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
            opt_dict["remesh"] = (remesh_choices,)

        return {
            "required": {
                "model": ("SPAR3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": opt_dict,
        }

    @staticmethod
    def _vertex_count(target_type, target_count):
        """Translate the user's reduction target into run_image's vertex budget."""
        if target_type == "none":
            return -1  # -1 is what the original code passes for "no target"
        if target_type == "face":
            # A closed triangle mesh has roughly twice as many faces as vertices.
            return target_count // 2
        return target_count

    @staticmethod
    def _to_rgba_image(image, mask):
        """Convert the first image of the batch (plus optional mask) to RGBA PIL."""
        pil_image = Image.fromarray(
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )

        if mask is not None:
            logging.info("Using Mask")
            if mask.dim() == 2:
                # Accept a single (H, W) mask as well as a (B, H, W) batch.
                mask = mask.unsqueeze(0)
            mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
                np.uint8
            )
            mask_pil = Image.fromarray(mask_np)  # 2-D uint8 -> mode "L"
            if mask_pil.size != pil_image.size:
                # putalpha() requires identical sizes; stretch the mask to fit.
                mask_pil = mask_pil.resize(pil_image.size, Image.BILINEAR)
            pil_image.putalpha(mask_pil)
        elif image.shape[3] != 4:
            logging.info("No mask or alpha channel detected, Converting to RGBA")
            pil_image = pil_image.convert("RGBA")

        return pil_image

    def predict(
        self,
        model,
        image,
        mask=None,
        foreground_ratio=1.3,
        texture_resolution=1024,
        pointcloud=None,
        remesh="none",
        target_type="none",
        target_count=1000,
        guidance_scale=3.0,
        seed=42,
    ):
        """Run SPAR3D on a single image.

        Args:
            model: Model returned by SPAR3DLoader.
            image: ComfyUI IMAGE batch; it must contain exactly one image.
            mask: Optional ComfyUI MASK used as the image's alpha channel.
            foreground_ratio: Passed to ``foreground_crop``.
            texture_resolution: Passed to ``run_image`` as ``bake_resolution``.
            pointcloud: Optional point cloud forwarded to ``run_image``.
            remesh: "none", "triangle" or "quad".
            target_type: "none", "vertex" or "face". Chooses how
                ``target_count`` is interpreted.
            target_count: Desired vertex or face count.
            guidance_scale: Stored on ``model.cfg.guidance_scale``.
            seed: Seed for Python, torch and numpy RNGs.

        Returns:
            ``([mesh], pointcloud_list)``. ``pointcloud_list`` is the model's
            point cloud, flattened to a Python list.

        Raises:
            ValueError: The batch holds more than one image, or no subject was
                reconstructed.
            ImportError: The requested remesh backend is unavailable.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # Fail fast, before mutating the model or doing any heavy work.
        if remesh == "triangle" and not TRIANGLE_REMESH_AVAILABLE:
            raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
        if remesh == "quad" and not QUAD_REMESH_AVAILABLE:
            raise ImportError("Quad remeshing requires pynim to be installed")

        vertex_count = self._vertex_count(target_type, target_count)
        pil_image = foreground_crop(self._to_rgba_image(image, mask), foreground_ratio)

        model.cfg.guidance_scale = guidance_scale
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

        logging.info("SPAR3D remesh mode: %s", remesh)
        # Half-precision autocast is only meaningful on CUDA. Disabling it
        # elsewhere avoids torch's "CUDA is not available" warning.
        use_cuda_autocast = comfy.model_management.get_torch_device().type == "cuda"
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=use_cuda_autocast
        ):
            mesh, glob_dict = model.run_image(
                pil_image,
                bake_resolution=texture_resolution,
                pointcloud=pointcloud,
                remesh=remesh,
                vertex_count=vertex_count,
            )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return (
            [mesh],
            # reshape (unlike view) also works on non-contiguous tensors.
            glob_dict["pointcloud"].reshape(-1).detach().cpu().numpy().tolist(),
        )
|
|
|
|
|
class SPAR3DSave:
    """ComfyUI output node that writes meshes to the output directory as .glb files.

    Each mesh is exported as binary glTF with normals and written to disk. The
    same bytes are also returned base64-encoded under ``glbs`` for UI preview.
    """

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def __init__(self):
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Write every mesh as ``<prefix>_<counter>_.glb`` and return UI previews.

        ``%batch_num%`` in the prefix is replaced by the mesh's index in ``mesh``.
        """
        output_dir = folder_paths.get_output_directory()
        # Resolve the destination (and scan the folder for the next counter)
        # once, then advance the counter per mesh. ComfyUI's own image saver
        # uses the same scheme.
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )

        glbs = []
        for batch_number, single_mesh in enumerate(mesh):
            glb_data = gltf.export_glb(trimesh.Scene(single_mesh), include_normals=True)
            logging.info("Generated GLB model with %d bytes", len(glb_data))

            # Use a separate name so `filename` keeps its placeholder for
            # later iterations.
            batch_filename = filename.replace("%batch_num%", str(batch_number))
            out_path = os.path.join(
                full_output_folder, f"{batch_filename}_{counter:05}_.glb"
            )
            with open(out_path, "wb") as f:
                f.write(glb_data)
            counter += 1

            glbs.append(base64.b64encode(glb_data).decode("utf-8"))

        return {"ui": {"glbs": glbs}}
|
|
|
|
|
class SPAR3DPointCloudLoader:
    """ComfyUI node that loads a point cloud or mesh file as a POINTCLOUD.

    The POINTCLOUD value is a flat Python list ``[x, y, z, r, g, b, ...]``.
    Colours are floats in [0, 1].
    """

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load_pointcloud"
    RETURN_TYPES = ("POINTCLOUD",)
    RETURN_NAMES = ("pointcloud",)

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file": ("STRING", {"default": None}),
            }
        }

    def load_pointcloud(self, file):
        """Load ``file`` with trimesh and flatten its vertices and colours.

        Returns ``(None,)`` when no path is given (None, empty or whitespace).
        Vertices without colour information default to white.
        """
        if file is None or not file.strip():
            return (None,)

        mesh = trimesh.load(file.strip())
        vertices = np.asarray(mesh.vertices, dtype=np.float64)

        # Some trimesh visual types have no `vertex_colors` attribute (e.g.
        # texture-based visuals). Fall back to white instead of crashing.
        vertex_colors = getattr(getattr(mesh, "visual", None), "vertex_colors", None)
        if vertex_colors is not None:
            colors = np.asarray(vertex_colors)[:, :3] / 255.0
        else:
            colors = np.ones((len(vertices), 3))

        # Interleave per point as x, y, z, r, g, b. Truncate to the shorter
        # array, as zip() did in the original implementation.
        count = min(len(vertices), len(colors))
        point_cloud = np.hstack([vertices[:count], colors[:count]]).ravel().tolist()
        return (point_cloud,)
|
|
|
|
|
class SPAR3DPointCloudSaver:
    """ComfyUI output node that writes a POINTCLOUD to a .ply file.

    Input is a flat list ``[x, y, z, r, g, b, ...]``, with colours expected in
    [0, 1].
    """

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save_pointcloud"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pointcloud": ("POINTCLOUD",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def save_pointcloud(self, pointcloud, filename_prefix):
        """Export ``pointcloud`` as ``<prefix>_<counter>_.ply`` in the output dir."""
        if pointcloud is None or len(pointcloud) == 0:
            return {"ui": {"text": "No point cloud data to save"}}

        points = np.asarray(pointcloud, dtype=np.float64).reshape(-1, 6)

        # Quantise colours to uint8 explicitly. Clipping guards against values
        # slightly outside [0, 1], which would otherwise wrap or be rejected
        # when converted.
        colors = np.round(np.clip(points[:, 3:6], 0.0, 1.0) * 255.0).astype(np.uint8)
        ply_data = trimesh.PointCloud(vertices=points[:, :3], colors=colors)

        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        # Keep the trailing "_" after the counter, matching the .glb naming in
        # this module. Without it, ComfyUI's counter scan cannot parse the
        # digits of earlier .ply files, so the counter never advances and
        # files get overwritten.
        out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.ply")
        ply_data.export(out_path)

        return {"ui": {"text": f"Saved point cloud to {out_path}"}}
|
|
|
|
|
# Human-readable titles shown in the ComfyUI node menu, keyed by node id.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SPAR3DLoader": "SPAR3D Loader",
    "SPAR3DPreview": "SPAR3D Preview",
    "SPAR3DSampler": "SPAR3D Sampler",
    "SPAR3DSave": "SPAR3D Save",
    "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
    "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
}

# Node id -> implementing class. Each id is the class's own name.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        SPAR3DLoader,
        SPAR3DPreview,
        SPAR3DSampler,
        SPAR3DSave,
        SPAR3DPointCloudLoader,
        SPAR3DPointCloudSaver,
    )
}

# Directory of frontend assets served alongside these nodes.
WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
|