import streamlit as st
import torch
from PIL import Image

from models import IndividualLandmarkViT
from utils import VisualizeAttentionMaps
from utils.transform_utils import make_test_transforms

st.title("PdiscoFormer Part Discovery Visualizer")

# Instructions
st.write("First, choose a model from the dropdown list. "
         "If you then upload an image, the discovered part assignment will be visualized. "
         "Each model discovers parts of the salient objects in an image, depending on its training dataset.")
st.write("If you choose one of the CUB or NABirds models, please upload a bird image.")
st.write("If you choose one of the Flowers models, please upload a flower image.")
st.write("If you choose one of the PartImageNet models, please upload an image from a PartImageNet class, "
         "such as land animals, birds, cars, bottles, or airplanes.")

model_options = [
    "ananthu-aniraj/pdiscoformer_cub_k_8",
    "ananthu-aniraj/pdiscoformer_cub_k_16",
    "ananthu-aniraj/pdiscoformer_cub_k_4",
    "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8",
    "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25",
    "ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50",
    "ananthu-aniraj/pdiscoformer_flowers_k_2",
    "ananthu-aniraj/pdiscoformer_flowers_k_4",
    "ananthu-aniraj/pdiscoformer_flowers_k_8",
    "ananthu-aniraj/pdiscoformer_nabirds_k_4",
    "ananthu-aniraj/pdiscoformer_nabirds_k_8",
    "ananthu-aniraj/pdiscoformer_nabirds_k_11",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_8",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_16",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_25",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_41",
    "ananthu-aniraj/pdiscoformer_pimagenet_seg_k_50",
]
model_name = st.selectbox("Select a model", model_options)

if model_name is not None:
    # The CUB and NABirds checkpoints were trained at a higher input resolution
    if "cub" in model_name or "nabirds" in model_name:
        image_size = 518
    else:
        image_size = 224

    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model from the Hub and switch to inference mode
    model = IndividualLandmarkViT.from_pretrained(model_name, input_size=image_size).eval().to(device)
    num_parts = model.num_landmarks
    # One extra label is reserved for the background
    amap_vis = VisualizeAttentionMaps(num_parts=num_parts + 1, bg_label=num_parts)
    test_transforms = make_test_transforms(image_size)

    # Upload an image
    image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if image_file is not None:
        image = Image.open(image_file).convert("RGB")
        image_tensor = test_transforms(image).unsqueeze(0).to(device)
        with torch.no_grad():
            maps, scores = model(image_tensor)
        coloured_map = amap_vis.show_maps(image_tensor, maps)
        st.image(coloured_map, caption="Attention Map", use_column_width=True)
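
# Usage note (a minimal sketch, not part of the original script): assuming this
# file is saved as app.py inside the PdiscoFormer repository, so that the local
# `models` and `utils` packages resolve, the app can be launched from a shell with:
#
#   streamlit run app.py
#
# Streamlit serves the app locally (by default at http://localhost:8501); model
# weights are downloaded from the Hugging Face Hub on first selection.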