# Spaces: Running on Zero
import spaces | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import subprocess | |
import tempfile | |
import os | |
import trimesh | |
import time | |
from datetime import datetime | |
import pytz | |
# Import potentially CUDA-initializing modules after 'spaces' | |
import torch | |
import src.depth_pro as depth_pro | |
import timm | |
import cv2 | |
print(f"Timm version: {timm.__version__}")
# Fetch pretrained weights at startup. The script's contents are not visible
# here; presumably it downloads the checkpoint that
# depth_pro.create_model_and_transforms() loads. TODO confirm.
_fetch = subprocess.run(["bash", "get_pretrained_models.sh"])
if _fetch.returncode != 0:
    # Kept non-fatal (e.g. the weights may already be on disk), but no longer
    # silent: otherwise a failed download only surfaces later as an obscure
    # model-loading error.
    print(f"Warning: get_pretrained_models.sh exited with status {_fetch.returncode}; model loading may fail")
def load_model_and_predict(image_path):
    """
    Run Depth Pro on one image file and return its depth map and focal length.

    A fresh model/transform pair is built on every call. It is placed on
    ``cuda:0`` when ``torch.cuda.is_available()`` and on the CPU otherwise.

    Args:
        image_path (str): Path of the image, read with ``depth_pro.load_rgb``.

    Returns:
        tuple: ``(depth, focallength_px)``. ``depth`` is the model's ``"depth"``
        output copied to a CPU NumPy array. ``focallength_px`` is the model's
        ``"focallength_px"`` output, returned without conversion.

    Raises:
        ValueError: If ``load_rgb`` yields fewer than two items.
    """
    # NOTE(review): rebuilding the model for every request is expensive. It is
    # not cached here because device placement across GPU-scheduled calls is
    # not visible from this file; confirm before optimising.
    use_cuda = torch.cuda.is_available()
    target = torch.device("cuda:0") if use_cuda else torch.device("cpu")

    net, preprocess = depth_pro.create_model_and_transforms()
    net = net.to(target)
    net.eval()

    loaded = depth_pro.load_rgb(image_path)
    if len(loaded) < 2:
        raise ValueError(f"Unexpected result from load_rgb: {loaded}")
    # The first item is the image. The last item is passed to the model as f_px.
    rgb, f_px = loaded[0], loaded[-1]
    print(f"Extracted focal length: {f_px}")

    tensor_in = preprocess(rgb).to(target)
    with torch.no_grad():
        outputs = net.infer(tensor_in, f_px=f_px)
    return outputs["depth"].cpu().numpy(), outputs["focallength_px"]
def resize_image(image_path, max_size=1024):
    """
    Downscale the input image so that its largest dimension does not exceed max_size.

    The aspect ratio is preserved. Images already within the limit keep their
    size; previously they were enlarged to exactly ``max_size``, which
    contradicted this contract. Images whose mode the PNG encoder cannot write
    (e.g. CMYK) are converted to RGB. The result is always written to a new
    temporary PNG file created with ``delete=False``, so the caller must
    remove it.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the temporary PNG file.
    """
    # NOTE(review): re-encoding as PNG presumably drops EXIF metadata, so any
    # focal length stored there is probably not seen by downstream loaders.
    # Confirm this is intended.
    # Pillow modes the PNG encoder writes directly; anything else becomes RGB.
    png_modes = ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA")
    with Image.open(image_path) as img:
        width, height = img.size
        ratio = max_size / max(width, height)
        if ratio < 1:
            # Clamp to 1 px so very thin images never get a zero-sized side.
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            # LANCZOS filter for high-quality downsampling
            img = img.resize(new_size, Image.LANCZOS)
        if img.mode not in png_modes:
            img = img.convert("RGB")
        # Save while the source file is still open (img may still be backed by it).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
    return temp_file.name
def _grid_faces(height, width):
    """
    Return the triangle index array for a regular height x width vertex grid.

    Vertices are laid out row-major (vertex ``(i, j)`` has index ``i * width + j``).
    Each grid cell yields two triangles, ``[idx, idx + width, idx + 1]`` then
    ``[idx + 1, idx + width, idx + width + 1]``, ordered as a row-by-row,
    column-by-column double loop would produce them. The result has shape
    ``(2 * (height - 1) * (width - 1), 3)``, or ``(0, 3)`` when either side
    is smaller than 2.
    """
    rows = np.arange(max(height - 1, 0))
    cols = np.arange(max(width - 1, 0))
    # Top-left vertex index of every cell, in row-major order.
    idx = (rows[:, None] * width + cols[None, :]).ravel()
    tri_a = np.stack((idx, idx + width, idx + 1), axis=1)
    tri_b = np.stack((idx + 1, idx + width, idx + width + 1), axis=1)
    # Interleave so cell k contributes rows 2k (tri_a) and 2k + 1 (tri_b).
    return np.stack((tri_a, tri_b), axis=1).reshape(-1, 3)


# NOTE(review): the comment below seems to describe a GPU time-budget decorator
# (e.g. a `spaces.GPU(duration=...)` wrapper) that is not applied here. Confirm.
# Increased duration to default 60 seconds
def generate_3d_model(depth, image_path, focallength_px, simplification_factor=1.0, smoothing_iterations=0, thin_threshold=0):
    """
    Generate a vertex-coloured 3D mesh from a depth map and the original image.

    Every pixel becomes one vertex, back-projected through a pinhole camera
    whose principal point is the image centre. Neighbouring pixels are joined
    into two triangles per grid cell. The optional clean-up steps run only
    when their parameters differ from the defaults.

    Args:
        depth: 2-D depth map (NumPy array or torch tensor). If its size differs
            from the image, it is bilinearly resized to the image size.
        image_path (str): Image supplying vertex colours and the saved texture.
        focallength_px: Focal length in pixels at the depth map's resolution,
            i.e. the resolution the depth was predicted at. It is rescaled
            together with the depth map.
        simplification_factor (float): Fraction of faces kept by quadric
            decimation. Values below 1.0 enable it.
        smoothing_iterations (int): Number of Taubin smoothing passes. Values
            above 0 enable it.
        thin_threshold (float): Edge-length threshold passed to
            remove_thin_features. Values above 0 enable it.

    Returns:
        tuple[str, str, str]: ``(view_model_path, download_model_path, texture_path)``,
        all written to the current working directory.

    Raises:
        Any exception raised while loading, meshing or exporting is printed and
        re-raised.
    """
    import shutil
    import uuid

    try:
        print("Starting 3D model generation")
        # Convert to RGB so the colour array always has exactly 3 channels.
        # RGBA, greyscale or palette inputs would otherwise break reshape(-1, 3).
        # The context manager closes the file handle promptly.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image_array = np.array(image)

        # Ensure depth is a NumPy array
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        image_height, image_width = image_array.shape[:2]
        depth_height, depth_width = depth.shape[:2]
        fx = fy = float(focallength_px)
        if (depth_height, depth_width) != (image_height, image_width):
            # The focal length is expressed in pixels of the depth map's
            # resolution. Rescale it with the depth map; otherwise X/Y would be
            # stretched relative to Z by the resize factor.
            fx *= image_width / depth_width
            fy *= image_height / depth_height
            depth = cv2.resize(depth, (image_width, image_height), interpolation=cv2.INTER_LINEAR)

        height, width = depth.shape
        print(f"3D model generation - Depth shape: {depth.shape}")
        print(f"3D model generation - Image shape: {image_array.shape}")

        # Principal point at the image centre
        cx, cy = width / 2, height / 2

        # Back-project each pixel (u, v) with depth Z through the pinhole model.
        uu, vv = np.meshgrid(np.arange(width), np.arange(height))
        Z = depth.flatten()
        X = ((uu.flatten() - cx) * Z) / fx
        Y = ((vv.flatten() - cy) * Z) / fy
        vertices = np.column_stack((X, Y, Z))

        # Normalize RGB colors to [0, 1]; one colour per vertex, row-major like the vertices.
        colors = image_array.reshape(-1, 3) / 255.0

        print("Generating faces")
        faces = _grid_faces(height, width)

        print("Creating mesh")
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors, process=False)

        # Mesh cleaning and improvement steps (only if not using default values)
        if simplification_factor < 1.0 or smoothing_iterations > 0 or thin_threshold > 0:
            print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        if simplification_factor < 1.0:
            print("Simplifying mesh")
            target_faces = int(len(mesh.faces) * simplification_factor)
            mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
            print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        if smoothing_iterations > 0:
            print("Smoothing mesh")
            from trimesh import smoothing as trimesh_smoothing
            # Trimesh.smoothed() only returns a smooth-*shaded* copy for
            # rendering and never moves vertices. Taubin filtering smooths the
            # geometry itself while limiting the shrinkage of plain Laplacian smoothing.
            trimesh_smoothing.filter_taubin(mesh, iterations=int(smoothing_iterations))
            print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        if thin_threshold > 0:
            print("Removing thin features")
            mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
            print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        # Unique names: a whole-second timestamp alone collides when two builds
        # finish within the same second.
        stem = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
        view_model_path = f'view_model_{stem}.obj'
        download_model_path = f'download_model_{stem}.obj'
        print("Exporting to view")
        mesh.export(view_model_path, include_texture=True)
        print("Exporting to download")
        # The download copy is identical to the view copy; copy it rather than
        # serialising the (potentially very large) mesh a second time.
        shutil.copyfile(view_model_path, download_model_path)

        # Save the texture image (the RGB-converted image)
        texture_path = f'texture_{stem}.png'
        image.save(texture_path)
        print("Export completed")
        return view_model_path, download_model_path, texture_path
    except Exception as e:
        print(f"Error in generate_3d_model: {str(e)}")
        raise
def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Collapse edges shorter than ``thickness_threshold`` and drop degenerate faces.

    Args:
        mesh: Mesh to clean (a ``trimesh.Trimesh`` in this file). It is
            modified in place and also returned.
        thickness_threshold (float): Edges whose Euclidean length is below this
            value, in the same units as the vertices, are collapse candidates.

    Returns:
        The same mesh object.

    NOTE(review): ``trimesh.Trimesh`` does not appear to provide a
    ``collapse_edge`` method. Previously every attempt then raised and was
    swallowed by a bare ``except:``, once per short edge, so only the
    degenerate-face clean-up had any effect. The net result is unchanged here
    but is now explicit. Confirm against the installed trimesh version.
    """
    collapse_edge = getattr(mesh, "collapse_edge", None)
    if collapse_edge is None:
        print("Edge collapse is not supported by this mesh type; skipping short-edge collapse")
    else:
        # Calculate edge lengths
        edges = mesh.edges_unique
        edge_points = mesh.vertices[edges]
        edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)
        # Identify short edges
        short_edges = edges[edge_lengths < thickness_threshold]
        failed = 0
        for edge in short_edges:
            try:
                collapse_edge(edge)
            except Exception:
                # Best effort: skip edges that cannot be collapsed. Unlike a bare
                # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
                failed += 1
        if failed:
            print(f"Skipped {failed} of {len(short_edges)} short edges that could not be collapsed")

    # Remove degenerate faces. Prefer nondegenerate_faces()/update_faces(),
    # the replacement for the deprecated remove_degenerate_faces(); fall back
    # to the latter on trimesh versions that lack it.
    if hasattr(mesh, "nondegenerate_faces"):
        mesh.update_faces(mesh.nondegenerate_faces())
    else:
        mesh.remove_degenerate_faces()
    return mesh
# NOTE(review): the comment below seems to describe a GPU time-budget decorator
# that is not applied here. Confirm.
# Increased duration to default 60 seconds
def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    """
    Rebuild the 3D model from a saved depth CSV using new mesh-processing settings.

    NOTE(review): no UI event in the visible code is wired to this function;
    the "Regenerate" button calls create_3d_model instead.

    Returns:
        tuple: ``(view_model_path, download_model_path, texture_path)`` from
        generate_3d_model. Exceptions propagate to the caller.
    """
    saved_depth = np.loadtxt(depth_csv, delimiter=',')
    view_path, download_path, texture = generate_3d_model(
        saved_depth,
        image_path,
        focallength_px,
        simplification_factor,
        smoothing_iterations,
        thin_threshold,
    )
    print("regenerated!")
    return view_path, download_path, texture
def predict_depth(input_image):
    """
    Estimate depth for an uploaded image and produce the artifacts shown in the UI.

    The image is downscaled to a temporary PNG (always deleted afterwards) and
    run through load_model_and_predict. The depth is then rendered as a
    colour-mapped PNG and also saved as a raw CSV.

    Args:
        input_image (str | None): Filepath from the ``gr.Image(type="filepath")``
            component. It is ``None`` when the user clears the image.

    Returns:
        tuple: ``(depth_map_png_path, focal_length_text, raw_depth_csv_path, focallength_px)``.
        ``(None, "", None, None)`` for a ``None`` input, which clears the outputs.
        ``(None, error_message_with_traceback, None, None)`` on failure.
    """
    if input_image is None:
        # The .change event also fires when the image is cleared. There is
        # nothing to predict, so clear the outputs instead of showing an error.
        return None, "", None, None

    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")
        depth, focallength_px = load_model_and_predict(temp_file)
        print(f"Raw depth type: {type(depth)}, focallength_px type: {type(focallength_px)}")
        if depth.ndim != 2:
            depth = depth.squeeze()
        print(f"Depth map shape: {depth.shape}")

        output_path = "depth_map.png"
        # Use an explicit figure and always close it, so a failure while
        # drawing or saving cannot leave an open pyplot figure behind.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            im = ax.imshow(depth, cmap='gist_rainbow')
            fig.colorbar(im, ax=ax, label='Depth [m]')
            ax.set_title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
            ax.axis('off')
            fig.savefig(output_path)
        finally:
            plt.close(fig)

        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth.cpu().numpy() if isinstance(depth, torch.Tensor) else depth, delimiter=',')
        print(f"Saved raw depth map to {raw_depth_path}")
        focallength_px = float(focallength_px)
        print(f"Converted focallength_px to float: {focallength_px}")
        print("Depth map created!")
        print(f"Returning - output_path: {output_path}, focallength_px: {focallength_px}, raw_depth_path: {raw_depth_path}, temp_file: {temp_file}")
        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, focallength_px
    except Exception as e:
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)
        return None, error_message, None, None
    finally:
        # The resized copy is only needed for inference.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
            print(f"Removed temporary file: {temp_file}")
def create_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    """
    Gradio handler: build the 3D model from the saved depth CSV and the input image.

    Args:
        depth_csv: Value of the raw-depth ``gr.File`` component (anything
            ``np.loadtxt`` accepts). It is ``None`` before a depth map exists.
        image_path (str | None): Filepath of the uploaded image.
        focallength_px (float | None): Focal length stored by predict_depth.
        simplification_factor, smoothing_iterations, thin_threshold: Mesh
            clean-up settings forwarded unchanged to generate_3d_model.

    Returns:
        tuple: ``(view_model_path, download_model_path, texture_path, status_message)``.
        The three paths are ``None`` when no model could be built, and the
        status message explains why.
    """
    # Catch the "clicked before a depth map exists" case up front. The user
    # sees an actionable message instead of a low-level loadtxt/float() error.
    if depth_csv is None or not image_path or focallength_px is None:
        return None, None, None, "Please upload an image and wait for the depth map before generating the 3D model."
    try:
        # Check the image first: it is cheap and avoids parsing a large CSV for nothing.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        depth = np.loadtxt(depth_csv, delimiter=',')
        print(f"Loading image from: {image_path}")
        view_model_path, download_model_path, texture_path = generate_3d_model(
            depth, image_path, focallength_px,
            simplification_factor, smoothing_iterations, thin_threshold
        )
        print("3D model generated!")
        return view_model_path, download_model_path, texture_path, "3D model created successfully!"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing the handler.
        error_message = f"An error occurred during 3D model creation: {str(e)}"
        print(error_message)
        return None, None, None, error_message
def get_last_commit_timestamp():
    """
    Return the time of the latest git commit, formatted as ``YYYY-MM-DD HH:MM:SS UTC``.

    The commit date is read from ``git log`` in ISO form (which carries a UTC
    offset) and converted to UTC. Any failure, such as git being missing, not
    being in a repository, or an unparsable output, is printed and yields
    ``"Unknown"``.
    """
    git_cmd = ['git', 'log', '-1', '--format=%cd', '--date=iso']
    try:
        completed = subprocess.run(git_cmd, stdout=subprocess.PIPE, check=True)
        stamp = completed.stdout.decode('utf-8').strip()
        committed_at = datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S %z")
        return committed_at.astimezone(pytz.UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
    except Exception as exc:
        print(f"Error getting last commit timestamp: {str(exc)}")
        return "Unknown"
# Create the Gradio interface with appropriate input and output components.
# Flow: uploading an image runs predict_depth; either 3D button then runs
# create_3d_model using the saved depth CSV, the image and the stored focal length.
# NOTE(review): the original indentation was lost, so which components sit
# inside each gr.Row() below is a reconstruction. Confirm against the
# deployed layout.
last_updated = get_last_commit_timestamp()
with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization")
    gr.Markdown(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
        "**Instructions:**\n"
        "1. Upload an image to generate the depth map.\n"
        "2. Click 'Generate 3D Model' to create the 3D visualization.\n"
        "3. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
        "4. Download the raw depth data as a CSV file or the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )
    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")
    focal_length = gr.Textbox(label="Focal Length")
    # Also serves as the depth input for create_3d_model (see the click handlers below).
    raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")
    generate_3d_button = gr.Button("Generate 3D Model")
    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")
        download_texture = gr.File(label="Download Texture (PNG)")
    with gr.Row():
        # Default values (1.0 / 0 / 0) disable the corresponding mesh clean-up step.
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1, label="Simplification Factor (1.0 = No simplification)")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=0, step=1, label="Smoothing Iterations (0 = No smoothing)")
        thin_threshold = gr.Slider(minimum=0, maximum=0.1, value=0, step=0.001, label="Thin Feature Threshold (0 = No thin feature removal)")
    regenerate_button = gr.Button("Regenerate 3D Model")
    model_status = gr.Textbox(label="3D Model Status")
    # Hidden components to store intermediate results
    # (the numeric focal length returned as predict_depth's fourth output).
    hidden_focal_length = gr.State()
    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, hidden_focal_length]
    )
    # Both buttons run the same handler with the current slider values.
    generate_3d_button.click(
        create_3d_model,
        inputs=[raw_depth_csv, input_image, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
        outputs=[view_3d_model, download_3d_model, download_texture, model_status]
    )
    regenerate_button.click(
        create_3d_model,
        inputs=[raw_depth_csv, input_image, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
        outputs=[view_3d_model, download_3d_model, download_texture, model_status]
    )
# Launch the Gradio interface with sharing enabled
# NOTE(review): share links are typically unnecessary when the app is already
# hosted on Hugging Face Spaces. Confirm whether share=True is wanted.
iface.launch(share=True)