Model Card for Pokémon Type Classification

This model uses a Vision Transformer (ViT) to classify a Pokémon image into one of 18 types.

It was developed as a final project for CS 310 and trained on a Pokémon image dataset.

Model Details

  • Developer: Xianglu (Steven) Zhu
  • Purpose: Pokémon type classification
  • Model Type: Vision Transformer (ViT) for image classification
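
The card does not say which pretrained checkpoint was fine-tuned. Its reported size of 85.8M parameters is consistent with a ViT-Base/16 backbone (224×224 input, 16×16 patches, 12 layers, hidden size 768) plus an 18-way linear classifier, so the sketch below assumes that configuration. It builds a randomly initialized model of that shape with `ViTConfig` and `ViTForImageClassification` from `transformers`. It does not load the trained weights, and `TYPE_LABELS` is an illustrative name, not part of the repository.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

# The 18 types, in the same index order as labels_dict in this card's usage example.
TYPE_LABELS = [
    "Grass", "Fire", "Water", "Bug", "Normal", "Poison", "Electric",
    "Ground", "Fairy", "Fighting", "Psychic", "Rock", "Ghost",
    "Ice", "Dragon", "Dark", "Steel", "Flying",
]

# Assumption: ViT-Base/16 hyperparameters (these also happen to be ViTConfig's defaults).
config = ViTConfig(
    image_size=224,
    patch_size=16,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    id2label=dict(enumerate(TYPE_LABELS)),
    label2id={name: i for i, name in enumerate(TYPE_LABELS)},
)

# Randomly initialized: same architecture as assumed above, but untrained.
model = ViTForImageClassification(config).eval()

# A 224x224 image is split into (224 // 16) ** 2 = 196 non-overlapping patches.
num_patches = (config.image_size // config.patch_size) ** 2
num_params = sum(p.numel() for p in model.parameters())

with torch.no_grad():
    logits = model(torch.zeros(1, 3, 224, 224)).logits

print("patches per image:", num_patches)
print("parameters:", num_params)
print("logits shape:", tuple(logits.shape))
```

Under these assumptions the encoder processes 197 tokens per image (196 patches plus a [CLS] token), and the classifier maps the final [CLS] representation to 18 logits, one per type.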

Getting Started

The following example downloads the model from the Hugging Face Hub and predicts the type of a single image:

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the fine-tuned model and its image processor from the Hugging Face Hub
# (ViTImageProcessor replaces the deprecated ViTFeatureExtractor)
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_image_processor = ViTImageProcessor.from_pretrained("NP-NP/pokemon_model")

# Resize to 224x224, convert to a [0, 1] tensor, then normalize with the
# per-channel mean and std stored in the model's preprocessor config
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_image_processor.image_mean, std=hf_image_processor.image_std)
])

# Mapping from type name to class index, and the inverse mapping
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}

# Load and preprocess the image (replace the path with your own image)
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Run inference without tracking gradients
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits  # shape: (1, 18)
    predicted_class_idx = torch.argmax(logits, dim=1).item()

predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
```
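
`torch.argmax` keeps only the single highest-scoring class, which hides how confident the model is (and many Pokémon have two types). A minimal sketch for inspecting the runner-up types, using only `torch.softmax` and `torch.topk`; `top_k_types` is a hypothetical helper, not something shipped with the model:

```python
import torch


def top_k_types(logits: torch.Tensor, idx_to_label: dict[int, str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most likely (type, probability) pairs for the first image in a batch.

    Assumes `logits` has shape (batch, num_labels), as returned by the model's `.logits`.
    """
    probs = torch.softmax(logits, dim=-1)[0]  # probabilities for the first image
    top = torch.topk(probs, k=k)               # highest k probabilities, sorted descending
    return [
        (idx_to_label[i], p)
        for i, p in zip(top.indices.tolist(), top.values.tolist())
    ]


# Toy logits over three labels, only to show the output format.
demo_labels = {0: "Grass", 1: "Fire", 2: "Water"}
demo_logits = torch.tensor([[2.0, 0.0, 1.0]])
for label, prob in top_k_types(demo_logits, demo_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

With the usage example above, `top_k_types(logits, idx_to_label)` would return the three most likely types for the input image along with their softmax probabilities.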

Model Size

  • Parameters: 85.8M
  • Tensor type: F32
  • Weights format: Safetensors
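
F32 means every parameter is stored as a 4-byte float, so the weights take roughly 85.8M × 4 bytes ≈ 343 MB. A quick back-of-the-envelope check; the half-precision row only illustrates what casting the weights (for example with `model.half()`) would save and is not a format the repository is known to provide:

```python
# Approximate weight-storage size from the parameter count on this card.
NUM_PARAMS = 85.8e6  # "85.8M params" (already rounded on the model page)
BYTES_PER_PARAM = {"F32": 4, "F16/BF16": 2}

for dtype, nbytes in BYTES_PER_PARAM.items():
    size_mb = NUM_PARAMS * nbytes / 1e6     # decimal megabytes
    size_mib = NUM_PARAMS * nbytes / 2**20  # binary mebibytes
    print(f"{dtype}: ~{size_mb:.0f} MB ({size_mib:.0f} MiB)")
```

The actual Safetensors file is slightly larger than this estimate because it also stores a small JSON header describing each tensor.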