import streamlit as st
import torch
from PIL import Image

from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.transform_utils import make_test_transforms
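# NOTE: IndividualLandmarkViT (the model), VisualizeAttentionMaps and make_test_transforms
# (visualization and pre-processing helpers) come from the PdiscoFormer code bundled with this Space.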

st.title("PdiscoFormer Part Discovery Visualizer")
# Instructions
st.write("First, choose a model from the dropdown list. "
         "Then upload an image, and the predicted part assignment will be visualized. "
         "Each model is trained to discover parts of the salient objects in the image, depending on its training dataset.")
st.write("If you choose one of the CUB or NABirds models, please upload a bird image.")
st.write("If you choose one of the Flower models, please upload a flower image.")
st.write("If you choose one of the PartImageNet models, please upload an image of a PartImageNet class, "
         "such as land animals, birds, cars, bottles or airplanes.")
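# Pretrained PdiscoFormer checkpoints on the Hugging Face Hub; the k_<N> suffix denotes the number of parts
# the model was trained to discover.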
model_options = ["ananthu-aniraj/pdiscoformer_cub_k_8", "ananthu-aniraj/pdiscoformer_cub_k_16",
                 "ananthu-aniraj/pdiscoformer_cub_k_4", "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
                 "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
                 "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
                 "ananthu-aniraj/pdiscoformer_flowers_k_2", "ananthu-aniraj/pdiscoformer_flowers_k_4",
                 "ananthu-aniraj/pdiscoformer_flowers_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_4",
                 "ananthu-aniraj/pdiscoformer_nabirds_k_8", "ananthu-aniraj/pdiscoformer_nabirds_k_11",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25", "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
                 "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50"]
model_name = st.selectbox("Select a model", model_options)

if model_name is not None:
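    # Pick the input resolution that matches the selected checkpoint: 518 px for the bird models, 224 px otherwise.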
    if "cub" in model_name or "nabirds" in model_name:
        image_size = 518
    else:
        image_size = 224
    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the model
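    # (weights are downloaded from the Hugging Face Hub and cached locally on first use)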
    model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size).eval().to(device)
    num_parts = model.num_landmarks
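    # The visualizer handles the num_parts foreground maps plus one background map; the background gets label num_parts.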
    amap_vis = VisualizeAttentionMaps(num_parts=num_parts + 1, bg_label=num_parts)
    test_transforms = make_test_transforms(image_size)
    image_name = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])  # Upload an image
    if image_name is not None:
        image = Image.open(image_name).convert("RGB")
        image_tensor = test_transforms(image).unsqueeze(0).to(device)
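        # Inference only: the forward pass returns the part assignment maps and the prediction scores;
        # only the maps are needed for the visualization.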
        with torch.no_grad():
            maps, scores = model(image_tensor)
        coloured_map = amap_vis.show_maps(image_tensor, maps)
        st.image(coloured_map, caption="Attention Map", use_column_width=True)
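
# To try the demo locally (assuming this file is saved as app.py), run: streamlit run app.py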