Model Card for Pokémon Type Classification
This model uses a Vision Transformer (ViT) to classify a Pokémon image into one of 18 types.
It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.
Model Details
- Developer: Xianglu (Steven) Zhu
- Purpose: Pokémon type classification
- Model Type: Vision Transformer (ViT) for image classification
Getting Started
Here’s how you can use the model for classification:
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor
# Load the fine-tuned model and its image processor
# (ViTImageProcessor replaces the deprecated ViTFeatureExtractor)
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_feature_extractor = ViTImageProcessor.from_pretrained("NP-NP/pokemon_model")
# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_feature_extractor.image_mean, std=hf_feature_extractor.image_std),
])
# Mapping of labels to indices and vice versa
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}
# Load and preprocess the image
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0) # shape: (1, 3, 224, 224)
# Make a prediction
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits
predicted_class_idx = torch.argmax(logits, dim=1).item()
predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
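The example above reports only the single highest-scoring type. If a ranked list with probabilities is more useful, the logits can be passed through a softmax and sorted. The following is a minimal sketch in plain Python, not part of the original project: `softmax`, `top_k_types`, and `TYPE_NAMES` are hypothetical names introduced here, the label order is copied from `labels_dict`, and the logits are made up for illustration.

```python
import math

# Label order copied from labels_dict in the usage example (index = position).
TYPE_NAMES = [
    "Grass", "Fire", "Water", "Bug", "Normal", "Poison", "Electric",
    "Ground", "Fairy", "Fighting", "Psychic", "Rock", "Ghost",
    "Ice", "Dragon", "Dark", "Steel", "Flying",
]


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_types(logits, k=3, names=TYPE_NAMES):
    """Return the k most probable (type name, probability) pairs.

    Hypothetical helper for illustration; not provided by the model repo.
    """
    if len(logits) != len(names):
        raise ValueError("expected one logit per type name")
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(names[i], probs[i]) for i in ranked[:k]]


# Made-up logits for illustration only. With the real model, pass
# logits[0].tolist() from the usage example instead.
example_logits = [0.0] * len(TYPE_NAMES)
example_logits[6] = 4.0  # Electric
example_logits[8] = 2.0  # Fairy

for name, prob in top_k_types(example_logits, k=3):
    print(f"{name}: {prob:.3f}")
```

Since many Pokémon have two types, looking at the top two or three scores may be more informative than the argmax alone, although the model itself is trained to output one label per image.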
Base model: google/vit-base-patch16-224
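The base model's name encodes a 16x16 patch size and a 224x224 input resolution, which is why the usage example resizes images to 224x224 before inference. As a rough illustration of what that implies for the encoder's input length, the sketch below counts the tokens for a square image. The function name `vit_sequence_length` is hypothetical and is not part of `transformers`.

```python
def vit_sequence_length(image_size=224, patch_size=16):
    """Tokens seen by the ViT encoder: one per patch plus one [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


print(vit_sequence_length())  # 14 * 14 patches + 1 [CLS] token = 197
```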