|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
# All three pretrained components are loaded from the same Hugging Face
# checkpoint, so its id is kept in one place.
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# The captioning model itself, the preprocessor that turns a PIL image into
# `pixel_values` for it, and the tokenizer that turns generated ids back into
# text (see predict_step for how each is used).
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
|
|
|
|
|
|
|
# Decoding settings shared by every call to model.generate().
max_length = 16  # forwarded to generate() as `max_length`
num_beams = 4  # forwarded to generate() as `num_beams` (beam search width)

# How many alternative captions to request per image; the UI creates one
# output textbox per caption.
# NOTE(review): presumably must stay <= num_beams, since beam search can only
# return as many sequences as it keeps beams — confirm against transformers'
# generate() documentation before raising it.
num_return_sequences = 3

gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
|
|
|
|
|
|
|
|
|
def predict_step(image):
    """Generate several candidate captions for a single image.

    Args:
        image: The uploaded image as an array-like supporting ``.astype``
            (the Gradio input is created with ``type='numpy'``), or ``None``
            when the user submitted without an image.

    Returns:
        A tuple of stripped caption strings, one per sequence returned by
        ``model.generate`` (``num_return_sequences`` are requested). When
        ``image`` is ``None``, a tuple of ``num_return_sequences`` empty
        strings is returned so every output textbox still gets a value.
    """
    if image is None:
        # No upload: avoid crashing on ``None.astype`` and keep the tuple
        # length equal to the number of output components.
        return ("",) * num_return_sequences

    # Let PIL infer the mode from the array's shape, then normalise to RGB.
    # Forcing mode 'RGB' only worked for 3-channel input: a 2-D grayscale
    # array raised "not enough image data" and a 4-channel RGBA array was
    # reinterpreted with the wrong pixel layout.
    i_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(
        pixel_values, **gen_kwargs, num_return_sequences=num_return_sequences
    )

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    # Gradio maps the returned tuple positionally onto the output textboxes.
    return tuple(pred.strip() for pred in preds)
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
# One image input (delivered to predict_step as a NumPy array) and one
# textbox per generated caption, labelled "Caption 1", "Caption 2", ...
inputs = gr.inputs.Image(type='numpy', label='Original Image')
outputs = [
    gr.outputs.Textbox(label=f'Caption {n}')
    for n in range(1, num_return_sequences + 1)
]

# Text rendered around the interface.
title = "Image Captioning using ViT + GPT2"
description = "Image caption is generated for the uploaded image using ViT and GPT2. For training, COCO Dataset was utilised. If you see any biases (gender, race, etc.) in our picture captioning model that we were unable to identify during our stress tests, please use the 'Flag' button to mark the offending image."
# NOTE(review): this link points to sachin/vit2distilgpt2, but the checkpoint
# loaded at the top of the file is nlpconnect/vit-gpt2-image-captioning —
# confirm which model repo is meant to be linked.
article = " <a href='https://huggingface.co/sachin/vit2distilgpt2'>Model Repo on Hugging Face Model Hub</a>"

# Clickable example inputs; relative paths, presumably resolved against the
# directory the app is launched from — confirm the files ship with the app.
examples = [
    ["test1.png"],
    ["test2.png"],
    ["test3.png"],
]
|
|
|
# Assemble the captioning UI and keep a reference to it, so it is launched
# through a name that actually exists.
# NOTE(review): gr.inputs / gr.outputs, theme="huggingface" and the
# enable_queue launch flag belong to older Gradio releases; confirm they are
# still accepted by the pinned Gradio version before upgrading it.
iface = gr.Interface(
    predict_step,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
)

# Launch exactly once. A second, bare `iface.launch()` used to follow this
# call, but `iface` was never assigned there, so it could only raise NameError
# once this call returned.
iface.launch(debug=True, enable_queue=True)