|
import os |
|
import cv2 |
|
import gradio as gr |
|
import numpy as np |
|
import spaces |
|
import torch |
|
import torch.nn.functional as F |
|
from gradio.themes.utils import sizes |
|
from PIL import Image |
|
from torchvision import transforms |
|
import tempfile |
|
|
|
class Config:
    """Filesystem locations and checkpoint registries for the demo.

    The keys of ``CHECKPOINTS`` and ``SEG_CHECKPOINTS`` are the labels offered in
    the UI dropdowns (in insertion order). The values are checkpoint file names
    relative to ``CHECKPOINTS_DIR``. A value of ``None`` means "no checkpoint".
    """

    # Everything is resolved relative to this source file's directory.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # Surface-normal estimators, keyed by model size.
    CHECKPOINTS = dict([
        ("0.3b", "sapiens_0.3b_normal_render_people_epoch_66_torchscript.pt2"),
        ("0.6b", "sapiens_0.6b_normal_render_people_epoch_200_torchscript.pt2"),
        ("1b", "sapiens_1b_normal_render_people_epoch_115_torchscript.pt2"),
        ("2b", "sapiens_2b_normal_render_people_epoch_70_torchscript.pt2"),
    ])

    # Background-removal choices; "no-bg-removal" maps to None (no model file).
    SEG_CHECKPOINTS = dict([
        ("fg-bg-1b (recommended)", "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"),
        ("no-bg-removal", None),
        ("part-seg-1b", "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2"),
    ])
|
|
|
class ModelManager:
    """Stateless helpers for loading TorchScript checkpoints and running inference."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load ``checkpoint_name`` from ``Config.CHECKPOINTS_DIR`` in eval mode on CUDA.

        Args:
            checkpoint_name: File name of a TorchScript checkpoint, or None.

        Returns:
            The loaded scripted module, or None when ``checkpoint_name`` is None.
        """
        if checkpoint_name is not None:
            scripted = torch.jit.load(os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name))
            scripted.eval()
            scripted.to("cuda")
            return scripted
        return None

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        """Run ``model`` on ``input_tensor`` and resize its output to ``(height, width)``.

        Bilinear ``F.interpolate`` needs a 4-D input, so the model output is
        presumably (N, C, h, w). TODO confirm for every checkpoint.
        """
        prediction = model(input_tensor)
        target_size = (height, width)
        return F.interpolate(prediction, size=target_size, mode="bilinear", align_corners=False)
|
|
|
class ImageProcessor:
    """Turn a PIL image into a surface-normal visualization plus a raw ``.npy`` dump."""

    def __init__(self):
        # Resize to (H, W) = (1024, 768) for the model input, convert to a [0, 1]
        # float tensor, then normalize with per-channel mean/std given on a 0-255 scale.
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255], std=[58.5/255, 57.0/255, 57.5/255]),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, normal_model_name: str, seg_model_name: str):
        """Estimate per-pixel surface normals for ``image``.

        Args:
            image: Input image; converted to RGB before inference.
            normal_model_name: Key into ``Config.CHECKPOINTS``.
            seg_model_name: Key into ``Config.SEG_CHECKPOINTS``; ``"no-bg-removal"``
                skips segmentation entirely.

        Returns:
            ``(visualization, npy_path)``: an RGB PIL image of the normals, and the
            path of a ``.npy`` file holding the (H, W, 3) normal map at the input
            resolution. Pixels the segmentation model assigns to class 0
            (background) are NaN in that file.

        Raises:
            ValueError: If ``image`` is None.
            KeyError: If a model name is not a known checkpoint key.
        """
        if image is None:
            raise ValueError("process_image() requires an input image, got None")
        # ToTensor yields one plane per band, but Normalize has exactly three
        # mean/std values, so RGBA / L / P images must be converted first.
        image = image.convert("RGB")
        height, width = image.height, image.width

        normal_model = ModelManager.load_model(Config.CHECKPOINTS[normal_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")

        normal_output = ModelManager.run_model(normal_model, input_tensor, height, width)
        # Take the single batch element explicitly: squeeze() would also drop a
        # spatial axis for 1-pixel-high/wide inputs and break the transpose.
        normal_map = normal_output[0].cpu().numpy().transpose(1, 2, 0)
        normal_map_vis = normal_map.copy()

        if seg_model_name != "no-bg-removal":
            seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
            seg_output = ModelManager.run_model(seg_model, input_tensor, height, width)
            # argmax class 0 is background; every other class counts as foreground.
            background = (seg_output.argmax(dim=1) == 0).cpu().numpy()[0]
            normal_map[background] = np.nan  # raw output: background is NaN
            normal_map_vis[background] = -1  # visualization: fixed (-1, -1, -1) vector

        normal_map_vis = self.visualize_normal_map(normal_map_vis)

        # NamedTemporaryFile(delete=False) replaces the deprecated, race-prone
        # tempfile.mktemp(): the file is created securely and kept for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as npy_file:
            np.save(npy_file, normal_map)
        return Image.fromarray(normal_map_vis), npy_file.name

    @staticmethod
    def visualize_normal_map(normal_map):
        """Convert a normal field (vectors on the last axis) into a uint8 RGB array.

        Each vector is scaled to (near) unit length. Components in [-1, 1] are then
        mapped linearly onto [0, 255].
        """
        lengths = np.linalg.norm(normal_map, axis=-1, keepdims=True)
        unit = normal_map / (lengths + 1e-5)  # epsilon guards zero-length vectors
        return ((unit + 1) / 2 * 255).astype(np.uint8)
|
|
|
class GradioInterface:
    """Build the Gradio Blocks UI for the Sapiens normal-estimation demo."""

    def __init__(self):
        # A single processor (and its preprocessing pipeline) serves every request.
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` app: header, inputs, outputs and Run button.

        The app is returned unlaunched; the caller is responsible for ``launch()``.
        """
        # Page-level CSS, injected into the page through the header's gr.HTML.
        app_styles = """
        <style>
        /* Global Styles */
        body, #root {
            font-family: Helvetica, Arial, sans-serif;
            background-color: #1a1a1a;
            color: #fafafa;
        }

        /* Header Styles */
        .app-header {
            background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
            padding: 24px;
            border-radius: 8px;
            margin-bottom: 24px;
            text-align: center;
        }

        .app-title {
            font-size: 48px;
            margin: 0;
            color: #fafafa;
        }

        .app-subtitle {
            font-size: 24px;
            margin: 8px 0 16px;
            color: #fafafa;
        }

        .app-description {
            font-size: 16px;
            line-height: 1.6;
            opacity: 0.8;
            margin-bottom: 24px;
        }

        /* Button Styles */
        .publication-links {
            display: flex;
            justify-content: center;
            flex-wrap: wrap;
            gap: 8px;
            margin-bottom: 16px;
        }

        .publication-link {
            display: inline-flex;
            align-items: center;
            padding: 8px 16px;
            background-color: #333;
            color: #fff !important;
            text-decoration: none !important;
            border-radius: 20px;
            font-size: 14px;
            transition: background-color 0.3s;
        }

        .publication-link:hover {
            background-color: #555;
        }

        .publication-link i {
            margin-right: 8px;
        }

        /* Content Styles */
        .content-container {
            background-color: #2a2a2a;
            border-radius: 8px;
            padding: 24px;
            margin-bottom: 24px;
        }

        /* Image Styles */
        .image-preview img {
            max-width: 100%;
            max-height: 512px;
            margin: 0 auto;
            border-radius: 4px;
            display: block;
        }

        /* Control Styles */
        .control-panel {
            background-color: #333;
            padding: 16px;
            border-radius: 8px;
            margin-top: 16px;
        }

        /* Gradio Component Overrides */
        .gr-button {
            background-color: #4a4a4a;
            color: #fff;
            border: none;
            border-radius: 4px;
            padding: 8px 16px;
            cursor: pointer;
            transition: background-color 0.3s;
        }

        .gr-button:hover {
            background-color: #5a5a5a;
        }

        .gr-input, .gr-dropdown {
            background-color: #3a3a3a;
            color: #fff;
            border: 1px solid #4a4a4a;
            border-radius: 4px;
            padding: 8px;
        }

        .gr-form {
            background-color: transparent;
        }

        .gr-panel {
            border: none;
            background-color: transparent;
        }

        /* Override any conflicting styles from Bulma */
        .button.is-normal.is-rounded.is-dark {
            color: #fff !important;
            text-decoration: none !important;
        }
        </style>
        """

        # Header markup. The Bulma href must be a valid jsDelivr npm path
        # ("package@version"); it is pinned to Bulma 0.9.3.
        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens: Normal Estimation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
                This demo showcases the finetuned normal estimation model. <br>
                Checkout other normal estimation baselines to compare: <a href="https://huggingface.co/spaces/Stable-X/normal-estimation-arena" style="color: #3273dc;">normal-estimation-arena</a>
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

        # Kept named `process_image`: Gradio derives the default API endpoint name
        # from the handler's __name__.
        def process_image(image, normal_model_name, seg_model_name):
            # Reject a missing upload here, before entering the GPU-decorated call,
            # and surface a readable message in the UI.
            if image is None:
                raise gr.Error("Please upload an input image first.")
            return self.image_processor.process_image(image, normal_model_name, seg_model_name)

        # Reload the page with ?__theme=dark unless it is already set, forcing dark mode.
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        examples_dir = os.path.join(Config.ASSETS_DIR, "images")

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        normal_model_name = gr.Dropdown(
                            label="Normal Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    # Created inside the Blocks context, which is what renders it;
                    # sorted() because os.listdir order is arbitrary.
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=[
                            os.path.join(examples_dir, img)
                            for img in sorted(os.listdir(examples_dir))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Normal Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background normal is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=process_image,
                inputs=[input_image, normal_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )

        return demo
|
|
|
def main():
    """Configure CUDA matmul precision, then build and launch the Gradio demo."""
    # TF32 tensor-core math is only enabled on compute capability 8.x+ (Ampere and newer).
    supports_tf32 = (
        torch.cuda.is_available()
        and torch.cuda.get_device_properties(0).major >= 8
    )
    if supports_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    GradioInterface().create_interface().launch(share=False)


if __name__ == "__main__":
    main()