metadata
library_name: transformers
base_model:
- google/vit-base-patch16-224
Model Card for Pokémon Type Classification
This model uses a Vision Transformer (ViT), based on google/vit-base-patch16-224, to classify Pokémon images into one of 18 types.
It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.
Model Details
- Developer: Xianglu (Steven) Zhu
- Purpose: Pokémon type classification
- Model Type: Vision Transformer (ViT) for image classification
Getting Started
Here’s how you can use the model for classification:
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor
# Load the pretrained model and its image processor.
# ViTImageProcessor replaces ViTFeatureExtractor, which is deprecated since transformers 4.26.
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_feature_extractor = ViTImageProcessor.from_pretrained("NP-NP/pokemon_model")
# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_feature_extractor.image_mean, std=hf_feature_extractor.image_std),
])
# Mapping of labels to indices and vice versa
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}
# Load and preprocess the image
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0) # shape: (1, 3, 224, 224)
# Make a prediction
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits
    predicted_class_idx = torch.argmax(logits, dim=1).item()
predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
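The snippet above reports only the single highest-scoring type. If you also want to see how confident the model is, one option is to turn the logits into probabilities and list the top few candidates. The following is a minimal sketch rather than part of the released model: `softmax`, `top_k_types`, and `TYPE_NAMES` are hypothetical helper names, the type order is copied from `labels_dict`, and the logits are made-up values for illustration. It uses only the standard library, so with the real model you would pass `logits[0].tolist()`.

```python
import math

# Type order copied from labels_dict in the snippet above (index = position).
TYPE_NAMES = [
    "Grass", "Fire", "Water", "Bug", "Normal", "Poison", "Electric",
    "Ground", "Fairy", "Fighting", "Psychic", "Rock", "Ghost",
    "Ice", "Dragon", "Dark", "Steel", "Flying",
]


def softmax(logits):
    """Convert raw scores to probabilities; subtract the max for numerical stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_types(logits, k=3, names=TYPE_NAMES):
    """Return the k most probable (type, probability) pairs, highest first."""
    if len(logits) != len(names):
        raise ValueError(f"expected {len(names)} logits, got {len(logits)}")
    probs = softmax(logits)
    ranked = sorted(zip(names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only; with the real model,
# pass logits[0].tolist() from the prediction step instead.
example_logits = [0.0] * 18
example_logits[6] = 4.0  # Electric
example_logits[4] = 2.0  # Normal

for name, prob in top_k_types(example_logits, k=2):
    print(f"{name}: {prob:.3f}")
```

The model is trained to pick one type per image, so the lower-ranked entries are best read as alternative guesses, not as a second type.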