# X / app.py
# Uploaded by jeduardogruiz — "Create app.py" (commit 4929bfb, verified).
import os
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms
import tempfile
class Config:
    """Static paths and checkpoint registries for the Sapiens normal-estimation demo."""

    # Resolve against this file's absolute path: if ``__file__`` were a bare
    # relative name, ``os.path.dirname`` would return "" and asset lookup would
    # silently depend on the process's current working directory.
    ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # UI label -> TorchScript file name (inside CHECKPOINTS_DIR) of each
    # normal-estimation model size offered in the dropdown.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2",
        "0.6b": "sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2",
        "1b": "sapiens_1b_normal_render_people_epoch_115_torchscript.pt2",
        "2b": "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2",
    }

    # UI label -> TorchScript file name of each background-removal option.
    # ``None`` marks the option that performs no segmentation at all.
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
class ModelManager:
    """Stateless helpers for loading TorchScript checkpoints and running them."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load a TorchScript checkpoint from ``Config.CHECKPOINTS_DIR`` onto CUDA.

        Args:
            checkpoint_name: File name inside the checkpoints directory, or
                ``None`` when no model is wanted.

        Returns:
            The loaded module, switched to eval mode and moved to ``"cuda"``,
            or ``None`` when ``checkpoint_name`` is ``None``.
        """
        if checkpoint_name is not None:
            scripted = torch.jit.load(os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name))
            scripted.eval()
            scripted.to("cuda")
            return scripted
        return None

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        """Run ``model`` on ``input_tensor`` and bilinearly resize its output.

        Runs under ``torch.inference_mode`` so no autograd state is recorded.
        Bilinear ``F.interpolate`` requires a 4-D ``(N, C, H, W)`` output; the
        spatial dims are resized to ``(height, width)``.
        """
        target_size = (height, width)
        return F.interpolate(model(input_tensor), size=target_size, mode="bilinear", align_corners=False)
class ImageProcessor:
    """Runs Sapiens normal estimation, optionally background-masked, on a PIL image."""

    def __init__(self):
        # Resize to the network's (1024, 768) input, convert to a [0, 1] tensor,
        # then normalize with per-channel mean/std given in 0-255 units.
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255], std=[58.5/255, 57.0/255, 57.5/255]),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, normal_model_name: str, seg_model_name: str):
        """Estimate surface normals for ``image`` and optionally mask the background.

        Args:
            image: Input PIL image; converted to RGB first if it is not already.
            normal_model_name: Key into ``Config.CHECKPOINTS`` (e.g. ``"1b"``).
            seg_model_name: Key into ``Config.SEG_CHECKPOINTS``;
                ``"no-bg-removal"`` skips segmentation entirely.

        Returns:
            ``(visualization, npy_path)``: an RGB ``PIL.Image`` of the normals at
            the input resolution, and the path of a persistent temporary
            ``.npy`` file holding the raw (H, W, C) float normal map, with
            background pixels set to NaN when segmentation ran.

        Raises:
            KeyError: If either model name is not a registered key.
        """
        # The 3-value Normalize above cannot handle RGBA / grayscale / palette
        # images, so coerce to RGB up front.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Models are loaded per call rather than cached on the instance (original
        # design). Presumably this keeps CUDA use inside the @spaces.GPU call;
        # confirm before adding a cache.
        normal_model = ModelManager.load_model(Config.CHECKPOINTS[normal_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")

        normal_output = ModelManager.run_model(normal_model, input_tensor, image.height, image.width)
        # Take the single batch item explicitly: squeeze() would also drop the
        # H or W axis for a 1-pixel-wide/tall image. (C, H, W) -> (H, W, C).
        normal_map = normal_output[0].cpu().numpy().transpose(1, 2, 0)
        # Separate copy so the visualization and the saved data can mark the
        # background differently.
        normal_map_vis = normal_map.copy()

        if seg_model_name != "no-bg-removal":
            seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
            seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
            # Class 0 is treated as background; every other class is foreground.
            seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
            background = seg_mask == 0
            normal_map[background] = np.nan  # NaN marks background in the .npy output
            normal_map_vis[background] = -1  # uniform fill; renders as a flat dark gray

        normal_map_vis = self.visualize_normal_map(normal_map_vis)

        # tempfile.mktemp is deprecated and racy (the name can be claimed by
        # another process before we write). Create the file atomically instead,
        # and keep it (delete=False) so Gradio can serve it for download.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as npy_file:
            np.save(npy_file, normal_map)
            npy_path = npy_file.name

        return Image.fromarray(normal_map_vis), npy_path

    @staticmethod
    def visualize_normal_map(normal_map):
        """Map per-pixel normal vectors (last axis) to a uint8 RGB image.

        Each vector is scaled to unit length; the 1e-5 epsilon avoids division
        by zero. Components in [-1, 1] are then mapped linearly to [0, 255].
        """
        normal_map_norm = np.linalg.norm(normal_map, axis=-1, keepdims=True)
        normal_map_normalized = normal_map / (normal_map_norm + 1e-5)
        normal_map_vis = ((normal_map_normalized + 1) / 2 * 255).astype(np.uint8)
        return normal_map_vis
class GradioInterface:
    """Builds the Gradio Blocks UI for the Sapiens normal-estimation demo."""

    # Page CSS, injected through the header HTML below. Kept as a class constant
    # so create_interface() only holds layout and event wiring.
    _APP_STYLES = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 100%;
max-height: 512px;
margin: 0 auto;
border-radius: 4px;
display: block;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
color: #fff !important;
text-decoration: none !important;
}
</style>
"""

    # Header markup: external stylesheets, the CSS above, title/description and
    # link buttons. The f-string is evaluated in the class body, where
    # _APP_STYLES is already bound.
    # NOTE(review): the Bulma href had been mangled into an email-protection
    # placeholder (so Bulma never loaded); it is restored here to a pinned 0.9.x
    # release. Confirm the exact version the page was designed against.
    _HEADER_HTML = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{_APP_STYLES}
<div class="app-header">
<h1 class="app-title">Sapiens: Normal Estimation</h1>
<h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
<p class="app-description">
Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
This demo showcases the finetuned normal estimation model. <br>
Checkout other normal estimation baselines to compare: <a href="https://huggingface.co/spaces/Stable-X/normal-estimation-arena" style="color: #3273dc;">normal-estimation-arena</a>
</p>
<div class="publication-links">
<a href="https://arxiv.org/abs/2408.12569" class="publication-link">
<i class="fas fa-file-pdf"></i>arXiv
</a>
<a href="https://github.com/facebookresearch/sapiens" class="publication-link">
<i class="fab fa-github"></i>Code
</a>
<a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
<i class="fas fa-globe"></i>Meta
</a>
<a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
<i class="fas fa-chart-bar"></i>Results
</a>
</div>
<div class="publication-links">
<a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
<i class="fas fa-user"></i>Demo-Pose
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
<i class="fas fa-puzzle-piece"></i>Demo-Seg
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
<i class="fas fa-cube"></i>Demo-Depth
</a>
<a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
<i class="fas fa-vector-square"></i>Demo-Normal
</a>
</div>
</div>
"""

    # Client-side script run on page load. If the URL does not already request
    # the dark theme, it reloads the page with ?__theme=dark.
    _FORCE_DARK_THEME_JS = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble and return the (not yet launched) ``gr.Blocks`` demo."""

        def process_image(image, normal_model_name, seg_model_name):
            # Gradio passes None when Run is clicked with no upload. Fail with a
            # readable message instead of an AttributeError in the GPU worker.
            if image is None:
                raise gr.Error("Please upload an input image first.")
            result, npy_path = self.image_processor.process_image(image, normal_model_name, seg_model_name)
            return result, npy_path

        examples_dir = os.path.join(Config.ASSETS_DIR, "images")

        with gr.Blocks(js=self._FORCE_DARK_THEME_JS, theme=gr.themes.Default()) as demo:
            gr.HTML(self._HEADER_HTML)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        normal_model_name = gr.Dropdown(
                            label="Normal Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        # os.listdir order is arbitrary; sort for a stable gallery.
                        examples=[
                            os.path.join(examples_dir, img)
                            for img in sorted(os.listdir(examples_dir))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Normal Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background normal is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")
            run_button.click(
                fn=process_image,
                inputs=[input_image, normal_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )
        return demo
def main():
    """Enable TF32 on Ampere-or-newer GPUs when one is visible, then launch the demo.

    TF32 is only switched on when CUDA is available and device 0 reports a
    compute-capability major version of at least 8. The Gradio app is built
    and launched locally (``share=False``).
    """
    supports_tf32 = torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8
    if supports_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    GradioInterface().create_interface().launch(share=False)


if __name__ == "__main__":
    main()