nielsr HF staff commited on
Commit
91e2667
·
1 Parent(s): 5471909

Update to COCO example image

Browse files
Files changed (1) hide show
  1. README.md +10 -13
README.md CHANGED
@@ -31,24 +31,21 @@ fine-tuned versions on a task that interests you.
31
 
32
  ### How to use
33
 
34
- Here is how to use this model to classify an image of CIFAR-100 into one of the 1,000 ImageNet classes:
35
 
36
  ```python
37
  from transformers import ViTFeatureExtractor, ViTForImageClassification
38
- from datasets import load_dataset
39
- import numpy as np
40
-
 
41
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
42
- model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
43
-
44
- dataset = load_dataset("cifar100", split='test')
45
- image = np.asarray(dataset[2]['img'], dtype=np.uint8)
46
- image = np.moveaxis(image, source=-1, destination=0) # change from (H, W, C) to (C, H, W)
47
-
48
- pixel_values = feature_extractor(image)
49
- outputs = model(pixel_values)
50
  logits = outputs.logits
51
- predicted_class = logits.argmax(-1)
 
52
  ```
53
 
54
  Currently, both the feature extractor and model support PyTorch. Tensorflow and JAX/FLAX are coming soon, and the API of ViTFeatureExtractor might change.
 
31
 
32
  ### How to use
33
 
34
+ Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
35
 
36
  ```python
37
  from transformers import ViTFeatureExtractor, ViTForImageClassification
38
+ from PIL import Image
39
+ import requests
40
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
41
+ image = Image.open(requests.get(url, stream=True).raw)
42
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
43
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
44
+ inputs = feature_extractor(images=image)
45
+ outputs = model(**inputs)
 
 
 
 
 
46
  logits = outputs.logits
47
+ # model predicts one of the 1000 ImageNet classes
48
+ predicted_class = logits.argmax(-1).item()
49
  ```
50
 
51
  Currently, both the feature extractor and model support PyTorch. Tensorflow and JAX/FLAX are coming soon, and the API of ViTFeatureExtractor might change.