---
library_name: transformers
base_model:
  - google/vit-base-patch16-224
---

Model Card for Pokémon Type Classification

This model uses a Vision Transformer (ViT), built on google/vit-base-patch16-224, to classify Pokémon images into one of 18 types.

It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.

Model Details

  • Developer: Xianglu (Steven) Zhu
  • Purpose: Pokémon type classification
  • Model Type: Vision Transformer (ViT) for image classification
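The base checkpoint's name, google/vit-base-patch16-224, describes its input format: 224 × 224 images split into 16 × 16 patches. That is why the usage example resizes every image to 224 × 224. The sketch below only does the token-count arithmetic. It assumes the standard ViT layout of one embedding per patch plus a single [CLS] token, and `vit_sequence_length` is an illustrative name, not a Transformers function.

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of tokens a standard ViT encoder sees for a square image
    (assumes one token per patch plus one [CLS] token)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + 1                       # add the [CLS] token


print(vit_sequence_length())  # 197
```

Feeding a different resolution would change the number of patches, so keeping the 224 × 224 resize matches what the base checkpoint was built for.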

Getting Started

The example below loads the model, preprocesses a single image, and prints its predicted type:

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the model and its image processor. ViTImageProcessor replaces the
# deprecated ViTFeatureExtractor and exposes the same image_mean / image_std.
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_processor = ViTImageProcessor.from_pretrained("NP-NP/pokemon_model")

# Preprocessing: resize to the 224x224 input size, convert to a tensor,
# and normalize with the processor's mean and standard deviation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_processor.image_mean, std=hf_processor.image_std)
])

# Mapping from type name to class index, plus the inverse mapping
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}

# Load the image (replace with the path to your own file) and add a batch dimension
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Run inference without tracking gradients
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(pixel_values=input_tensor)
    logits = outputs.logits  # shape: (1, 18), one score per type
    predicted_class_idx = torch.argmax(logits, dim=1).item()

predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
```
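Since the model produces one score per type, it can help to look at the runner-up predictions as well as the top one, for example on Pokémon that have two types (although the second-ranked type is not guaranteed to be the secondary type). The helper below is a minimal sketch, not part of the original project: `top_k_types` is a hypothetical name, and it assumes the logits for one image are first converted to a plain Python list, so the softmax can be checked without downloading the model.

```python
import math
from typing import Dict, List, Sequence, Tuple


def top_k_types(
    logits: Sequence[float],
    idx_to_label: Dict[int, str],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the k most likely types with their softmax probabilities,
    sorted from most to least likely."""
    # Subtract the largest logit before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(idx_to_label[i], probs[i]) for i in ranked[:k]]
```

With the `logits` and `idx_to_label` from the example above, `top_k_types(logits[0].tolist(), idx_to_label)` returns a list of `(type, probability)` pairs. Working on a plain list keeps the helper independent of PyTorch; inside a PyTorch pipeline, `torch.softmax` followed by `torch.topk` would do the same job.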