---
license: apache-2.0
tags:
- image-captioning
languages:
- en
datasets:
- michelecafagna26/hl
language:
- en
metrics:
- sacrebleu
- rouge
library_name: transformers
---

## ClipCap fine-tuned for Rationale Image Captioning

[ClipCap](https://arxiv.org/abs/2111.09734) base model fine-tuned on the [HL Dataset](https://huggingface.co/datasets/michelecafagna26/hl) for **high-level rationale description generation**.

## Model fine-tuning

We fine-tune the language model (LM) and the mapping network, starting from the ClipCap model pretrained on COCO.

- Trained for 8 epochs
- Learning rate: 5e-5
- Adam optimizer
- Half-precision (fp16)
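
For reference, the fine-tuning setup listed above can be recorded in a small config object. This is a hypothetical sketch: the class and field names are illustrative and are not taken from the authors' training code. Only the values come from this card (`prefix_length = 10` is the value used in the inference example).

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class HypotheticalFineTuneConfig:
    """Hypothetical record of the fine-tuning setup described in this card.

    Field names are illustrative; they do not mirror the actual training script.
    """

    init_checkpoint: str = "ClipCap pretrained on COCO"  # starting point
    train_language_model: bool = True   # the GPT-2 LM is updated
    train_mapping_network: bool = True  # the CLIP-to-prefix mapping network is updated
    epochs: int = 8
    learning_rate: float = 5e-5
    optimizer: str = "adam"
    fp16: bool = True                   # half-precision training
    prefix_length: int = 10             # value used in the inference example


config = HypotheticalFineTuneConfig()
print(asdict(config))
```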

## Test set metrics 🧾

| CIDEr | SacreBLEU | ROUGE-L |
|-------|-----------|---------|
| 78.04 | 11.71     | 25.76   |
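
ROUGE-L scores a generated caption by the longest common subsequence (LCS) of tokens it shares with a reference. As a rough illustration of what that column measures, here is a minimal sentence-level sketch that assumes lowercase whitespace tokenization, combines LCS precision and recall with F1, and keeps the best score over several references. The tokenizer, F-measure weighting and corpus aggregation behind the reported 25.76 are not stated in this card, so this code will not reproduce that number; note also that the table appears to use a 0-100 scale, while this function returns a value in [0, 1].

```python
from typing import List, Sequence


def lcs_length(a: Sequence[str], b: Sequence[str]) -> int:
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate: str, references: List[str]) -> float:
    """Sentence-level ROUGE-L F1 in [0, 1], taking the best score over the references.

    Simplified sketch: lowercase + whitespace tokenization, F1 (beta = 1).
    """
    cand = candidate.lower().split()
    best = 0.0
    for ref in references:
        ref_toks = ref.lower().split()
        lcs = lcs_length(cand, ref_toks)
        if lcs == 0:
            continue  # no overlap (also covers empty strings)
        precision = lcs / len(cand)
        recall = lcs / len(ref_toks)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best


score = rouge_l_f1("she is posing for a photo",
                   ["the woman is posing for the camera"])
print(round(score, 4))  # LCS is "is posing for": P = 3/6, R = 3/7, F1 = 6/13 -> 0.4615
```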

## Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1191fsBtOW1p_Qy17lVpr-2TaFqL24acE?usp=sharing)

## Installation

```bash
pip install git+https://github.com/michelecafagna26/CLIPCap.git
```

## Download the model

```bash
git lfs install  # if not installed
git clone https://huggingface.co/michelecafagna26/clipcap-base-captioning-ft-hl-rationales
```

## Model in Action

```python
from clipcap import ClipCaptionModel
from transformers import GPT2Tokenizer
import torch
import clip
import requests
from PIL import Image

model_path = "clipcap-base-captioning-ft-hl-rationales/pytorch_model.pt"  # change accordingly

# load CLIP and the GPT-2 tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10

# load ClipCap
model = ClipCaptionModel(prefix_length, tokenizer=tokenizer)
model.from_pretrained(model_path)
model = model.eval()
model = model.to(device)

# load the image
img_url = "https://datasets-server.huggingface.co/assets/michelecafagna26/hl/--/default/train/0/image/image.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# extract the prefix: map the CLIP image embedding to prefix_length GPT-2 embeddings
image = preprocess(raw_image).unsqueeze(0).to(device)
with torch.no_grad():
    prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)

# generate the caption with beam search and keep the first returned candidate
caption = model.generate_beam(embed=prefix_embed)[0]
print(caption)
# >> "she is posing for a photo."
```
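
`generate_beam` decodes the caption with beam search. The toy sketch below illustrates the idea in plain Python: at each step every partial caption is extended with every candidate next token, only the `beam_size` highest-scoring sequences (by summed log-probability) survive, sequences that end in the stop token are set aside, and finished captions are ranked by their mean log-probability per token. `TOY_LM` is a hypothetical stand-in for GPT-2 conditioned on the CLIP prefix, and the defaults (`beam_size=3`, `max_len=10`, stop token `"."`) are illustrative assumptions, not the settings of the `clipcap` package.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical toy next-token distributions keyed by the previous token.
# In ClipCap these probabilities would come from GPT-2 conditioned on the image prefix.
TOY_LM: Dict[str, Dict[str, float]] = {
    "<s>": {"she": 0.6, "a": 0.4},
    "she": {"is": 0.9, ".": 0.1},
    "a": {"woman": 0.5, "girl": 0.5},
    "woman": {"is": 0.7, ".": 0.3},
    "girl": {"is": 0.7, ".": 0.3},
    "is": {"posing": 0.8, ".": 0.2},
    "posing": {".": 1.0},
}


def toy_next_probs(prev_token: str) -> Dict[str, float]:
    # Unknown contexts simply end the sentence.
    return TOY_LM.get(prev_token, {".": 1.0})


def beam_search(
    next_probs: Callable[[str], Dict[str, float]],
    beam_size: int = 3,
    max_len: int = 10,
    stop_token: str = ".",
) -> List[str]:
    """Return candidate captions, best first (ranked by mean log-probability per token)."""
    beams: List[Tuple[List[str], float]] = [(["<s>"], 0.0)]
    finished: List[Tuple[List[str], float]] = []
    for _ in range(max_len):
        # Extend every live beam with every possible next token.
        candidates = [
            (tokens + [tok], score + math.log(p))
            for tokens, score in beams
            for tok, p in next_probs(tokens[-1]).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == stop_token:
                finished.append((tokens, score))  # caption is complete
            else:
                beams.append((tokens, score))     # keep expanding next step
        if not beams:
            break
    finished.extend(beams)  # keep unfinished beams if max_len was reached
    # Length-normalize so longer captions are not penalized just for having more terms.
    finished.sort(key=lambda c: c[1] / (len(c[0]) - 1), reverse=True)
    return [" ".join(tokens[1:]) for tokens, _ in finished]


captions = beam_search(toy_next_probs)
print(captions[0])  # she is posing .
```

With this toy table a greedy decoder would pick the same caption; beam search pays off when an early, less likely token leads to a much more probable continuation.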

## BibTeX and citation info

```BibTeX
@inproceedings{cafagna2023hl,
  title={{HL} {D}ataset: {V}isually-grounded {D}escription of {S}cenes, {A}ctions and {R}ationales},
  author={Cafagna, Michele and van Deemter, Kees and Gatt, Albert},
  booktitle={Proceedings of the 16th International Natural Language Generation Conference (INLG'23)},
  address={Prague, Czech Republic},
  year={2023}
}
```