I’ve been using the HuggingFace library for quite some time now. I follow the tutorials, swap the tutorial data for my project data and get very good results. I wanted to dig a little deeper into how BERT and BERT-based models do classification. There is one key feature I don’t understand: the [CLS] token, which is responsible for the actual classification. I hope smart people here can answer my questions, because I haven’t been able to find the answers on my own.
When I searched for what the [CLS] token actually represents, most of the results said that “it is an aggregate representation of the sequence”. I can understand this part. Before BERT, people used different techniques to represent documents, ranging from averaging the word vectors of the document to computing document vectors with doc2vec. I can also understand that by stacking a linear classification layer on top and feeding in the values for the [CLS] token (768-dim for a bert-base-uncased model), we can end up classifying the sequence.
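To make the "linear layer on top of [CLS]" setup concrete, here is a minimal plain-Python sketch. It assumes the encoder has already produced one final hidden vector per token (`hidden_states`, seq_len × hidden_size) with [CLS] at index 0; the tiny sizes, the weights and the helper names `linear_head`/`softmax` are hypothetical stand-ins for a real 768-dim model and a trained linear layer.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def linear_head(vector, weights, bias):
    # One logit per class: dot(weights[c], vector) + bias[c].
    return [sum(w * x for w, x in zip(row, vector)) + b
            for row, b in zip(weights, bias)]

# Hypothetical final-layer hidden states for "[CLS] tok1 tok2 [SEP]",
# with hidden_size = 3 instead of 768.
hidden_states = [
    [0.2, -1.0, 0.5],   # position 0: [CLS]
    [1.1, 0.3, -0.2],
    [0.0, 0.4, 0.9],
    [-0.5, 0.1, 0.2],   # [SEP]
]

cls_vector = hidden_states[0]        # only the [CLS] row feeds the classifier
weights = [[1.0, 0.0, 0.0],          # hypothetical 2-class head (2 x 3)
           [0.0, -1.0, 1.0]]
bias = [0.0, 0.0]

logits = linear_head(cls_vector, weights, bias)
probs = softmax(logits)
predicted_class = probs.index(max(probs))
print(logits, probs, predicted_class)
```

As far as I know, HuggingFace's `BertForSequenceClassification` does essentially this, except that the [CLS] vector first goes through BERT's pooler (a dense layer with tanh) and dropout before the classification layer.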
Here are my questions:
Is my above understanding of the [CLS] token correct?
Why is it always the first token? Why not the second, third or last? Did the authors of the original BERT paper settle on the first position by trial and error?
How exactly does it “learn” the representation of the sequence? I mean, it’s basically trained in the same way as the other input tokens in the sequence, so what makes it special enough to represent the entire sequence? I couldn’t find any explanation for this in the paper or in my searches afterwards.
Is it at all possible to get back the original sequence using the [CLS] token (I think not but worth asking)?
I hope I can find some answers to these questions (or at least pointers to resources where I can find them). Please let me know if this is not the correct place to post these questions and where I should post them.
I believe the “first token” was chosen arbitrarily or for convenience.
In practice, you can fine-tune on a classification task using any token or the “average of tokens” (GlobalPooling1D).
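The pooling options mentioned in the reply above can be sketched as follows. "GlobalPooling1D" is read here as average pooling over the token axis (like Keras's `GlobalAveragePooling1D`), with padding excluded via the attention mask. That reading, the toy vectors and the function names are assumptions for illustration.

```python
def first_token_pool(hidden_states):
    # [CLS]-style pooling: take the vector at position 0.
    return list(hidden_states[0])

def token_pool(hidden_states, index):
    # Pool from an arbitrary position instead of [CLS].
    return list(hidden_states[index])

def masked_mean_pool(hidden_states, attention_mask):
    # Average only over real tokens (mask == 1), ignoring padding.
    hidden_size = len(hidden_states[0])
    total = [0.0] * hidden_size
    count = 0
    for vec, m in zip(hidden_states, attention_mask):
        if m:
            count += 1
            for i, x in enumerate(vec):
                total[i] += x
    return [t / count for t in total]

# Hypothetical final hidden states: [CLS], one word, [SEP], one padding slot.
hidden_states = [
    [1.0, 2.0],   # [CLS]
    [3.0, 4.0],
    [5.0, 6.0],   # [SEP]
    [9.0, 9.0],   # padding
]
attention_mask = [1, 1, 1, 0]

cls_vec = first_token_pool(hidden_states)
second_vec = token_pool(hidden_states, 1)
mean_vec = masked_mean_pool(hidden_states, attention_mask)
print(cls_vec, second_vec, mean_vec)
```

Whichever vector is chosen is then fed to the same kind of linear head. Whether [CLS] and [SEP] should be included in the mean is a design choice. Here they are, simply because their mask value is 1.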
Thanks for the reply. But aren’t the other tokens specific to a particular input token, as opposed to the [CLS] token, which doesn’t correspond to any input token? If that’s the case, how does it make sense to fine-tune on any token for our classification?
This is what is tripping me up. Is there no reasoning, empirical or otherwise, behind creating a token called [CLS] to be used as input for downstream classification tasks?
I may have been wrong when I said any token would do. If you have time, maybe you could run an experiment on that.
My intuition is that at first each of the other tokens may indeed represent its original input token. But if you fine-tune on any of them (backpropagation), it could also perform as well as [CLS] (I’ve never actually tried it).
One thing from my experience in Kaggle NLP competitions, however, is that GlobalPooling1D is not inferior to [CLS].
Using it as the first token is not “special” or new. Earlier NLP approaches also often had a beginning-of-sentence (BOS) token or similar. You wouldn’t want the token to be in between other tokens either. The reason for this is not so much [CLS] itself but the other tokens: the positioning of tokens relative to each other is important, because positional encoding makes a token’s representation depend on its position in the sequence. Linguistically, you’d therefore want to keep the sequence order as-is, without any extra information floating in between.
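A toy illustration of the positioning argument, assuming learned absolute position embeddings indexed by position, as in BERT. Prepending [CLS] shifts every word by the same amount (and since [CLS] is always present, words always start at position 1). Inserting it in the middle opens a gap between "name" and "is". The word list and the helper name `gaps` are hypothetical.

```python
def gaps(positions):
    # Distance between each pair of neighbouring real tokens.
    return [b - a for a, b in zip(positions, positions[1:])]

words = ["my", "name", "is", "abc"]

# [CLS] prepended, as in BERT: the words occupy positions 1..4.
cls_first = ["[CLS]"] + words
pos_first = [i for i, t in enumerate(cls_first) if t != "[CLS]"]

# Hypothetical alternative: [CLS] inserted between "name" and "is".
cls_middle = words[:2] + ["[CLS]"] + words[2:]
pos_middle = [i for i, t in enumerate(cls_middle) if t != "[CLS]"]

print(pos_first, gaps(pos_first))     # [1, 2, 3, 4] [1, 1, 1]
print(pos_middle, gaps(pos_middle))   # [0, 1, 3, 4] [1, 2, 1]
```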
AFAIK special tokens cannot be masked during pretraining. So [CLS] is always at the front, and its importance is learnt through attention, like other tokens, but without it ever having to be “predicted”. In the second pretraining task, next sentence prediction, its final hidden state serves as the input for classification.
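The BERT paper's masked-LM recipe selects 15% of token positions. Of those, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. The plain-Python sketch below follows that recipe and, as the reply above says, never selects special tokens. As far as I know, HuggingFace's `DataCollatorForLanguageModeling` achieves the same exclusion with a special-tokens mask. The function `mask_for_mlm`, the word-level tokens and the seeded RNG are hypothetical illustrations, not library code.

```python
import random

SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}

def mask_for_mlm(tokens, vocab, mask_prob=0.15, seed=0):
    """Return (masked_tokens, labels); a label of None means 'not a prediction target'."""
    rng = random.Random(seed)
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        # Special tokens are never selected, so [CLS] is never asked to be predicted.
        if tok in SPECIAL_TOKENS or rng.random() >= mask_prob:
            continue
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            masked[i] = "[MASK]"            # 80%: replace with [MASK]
        elif r < 0.9:
            masked[i] = rng.choice(vocab)   # 10%: replace with a random token
        # otherwise (10%): keep the original token unchanged
    return masked, labels

tokens = ["[CLS]", "hello", "my", "name", "is", "abc", "[SEP]"]
vocab = ["hello", "my", "name", "is", "abc", "lovely", "weather", "today"]
masked, labels = mask_for_mlm(tokens, vocab, mask_prob=0.5, seed=1)
print(masked)
print(labels)
```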
Without fine-tuning? No. Also note that BERT only has an encoder. To “generate” the original tokens, you’d typically need a decoder. You could try something like an auto-encoder, or set up an encoder-decoder similar to single-representation MT, but the chances of reproducing exactly the same input sentence are small.
It may also be interesting to read Table 7 of their paper. They tried not only the final state of [CLS] in downstream tasks but also different feature extractions across the model.
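If you want to try that layer-wise idea on the [CLS] vector, here is a sketch. It assumes `hidden_states` is a list with one entry per layer (the embedding output first, then each encoder layer), each a seq_len × hidden_size list. HuggingFace models return a tuple in this layout when called with `output_hidden_states=True`. Applying these combinations to [CLS] specifically is my assumption. The strategy names are hypothetical, and "sum_last_four" is an unweighted sum rather than a learned weighted one.

```python
def cls_from_layers(hidden_states, strategy="concat_last_four"):
    """hidden_states: list of layers, each a list of token vectors (seq_len x hidden)."""
    cls_per_layer = [layer[0] for layer in hidden_states]   # [CLS] is position 0
    if strategy == "last":
        return list(cls_per_layer[-1])
    if strategy == "second_to_last":
        return list(cls_per_layer[-2])
    if strategy == "concat_last_four":
        out = []
        for vec in cls_per_layer[-4:]:
            out.extend(vec)   # hidden_size * 4 features
        return out
    if strategy == "sum_last_four":
        return [sum(vals) for vals in zip(*cls_per_layer[-4:])]
    raise ValueError(f"unknown strategy: {strategy}")

# Toy stand-in: embedding output + 5 layers, 2 tokens, hidden_size 2.
# Layer l has [CLS] vector [l, l + 0.5].
hidden_states = [[[float(l), float(l) + 0.5], [0.0, 0.0]] for l in range(6)]

for s in ("last", "second_to_last", "concat_last_four", "sum_last_four"):
    print(s, cls_from_layers(hidden_states, s))
```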
OK, so I am gonna try to explain it as best I can. Say I give you a relatively long passage (512 words) and ask you to summarize it. You’re gonna do that by reading it and picking out the words that convey the gist of the passage.
In other words, your summary can be described (obviously with some oversimplification) as a weighted average of words, where a few words get a lot of “attention” (the word is used in two senses here: the everyday one and the transformer one) and others not so much.
This is how I like to think of the [CLS] token: a weighted average of the words such that the representation of the whole sequence is captured.
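The "weighted average" intuition corresponds to a single attention step in which [CLS] is the query. This toy uses one head and identity projections (Q = K = V = the raw vectors) on made-up 2-d token vectors. Real BERT uses learned projections, 12 heads and 768 dims. It only covers the attention part, not the residual connection, layer norm and feed-forward sub-layer that follow it in each block.

```python
import math

def softmax(scores):
    # Subtract the max score for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    # Output = sum_i weights[i] * values[i]: a weighted average of the values
    # (the weights are non-negative and sum to 1).
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return out, weights

# Hypothetical 2-d token vectors.
tokens = ["[CLS]", "great", "the", "movie"]
vectors = [[0.0, 1.0], [0.0, 4.0], [1.0, 0.0], [1.0, 1.0]]

cls_out, weights = attend(vectors[0], vectors, vectors)
for tok, w in zip(tokens, weights):
    print(f"{tok:>6}: {w:.3f}")
print("attention output for [CLS]:", cls_out)
```

Here "great" gets the largest weight because its vector aligns most with the [CLS] query. Fine-tuning, in this picture, adjusts the projections so that the words the task cares about get the large weights.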
Now the question arises: why does [CLS] end up being this overall representation of the sequence? The way I like to think about it is this: without fine-tuning, the [CLS] token ought to be like a loose average.
However, it is during fine-tuning that the downstream task forces the [CLS] weights (remember it has a fixed position, so some transformer weights are practically dedicated to this position, which can never be taken by any other token) to morph into the optimal weighted average that the task at hand needs (thanks to good ol’ backprop).
So if you’re doing sentiment classification, the weights for adjective tokens describing emotions end up being bigger. If you’re doing toxic tweet classification, abusive words end up getting more attention, and so on.
Is this all backed by research, or am I telling you a nice story? Well, it is largely the latter, with some exceptions.
Hope it still helps!
@lewtun You are damn good at explaining stuff, what do you think?
Hmm, maybe it is confusing… But how do we know that there isn’t a set of weights that leads to a weighted-average vector very close to the [CLS] vector?
Finally, it is obviously not a simple weighted average: each token’s vector representation is repeatedly influenced in every encoder layer, so the final representation (not to mention the passes through the FC layers) is a bit more than a straightforward weighted average. But IMHO the idea holds because of how attention is computed.
What you mention is true for all tokens, not only for CLS. The value of CLS is influenced by other tokens, just like other tokens are influenced by their context (attention).
But that is not what a weighted average is. CLS still has its own token embedding and starts from that. It is not only the result of operations on the other tokens.
Yeah, I think you see CLS as just any other token, but for me the fact that it has its own dedicated position, and that it is there at that position in every single pretraining/fine-tuning step, already makes it uniquely positioned to carry sentence-level semantics.
It would be nice to see others’ takes; I am the farthest thing from an authority on this subject.
And I do not disagree with that. It is a special token, always in the same position, similar to how other BOS tokens are used. But when you say that CLS is only a “weighted average” of the other tokens, that is simply not correct. Terminology is important here. A weighted average is something very specific, and the value for CLS is calculated in a much more intricate way, taking into account both its own embeddings (token/position) and the context.
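The disagreement can be made concrete with one attention sub-layer. The attention output for [CLS] is a weighted average of the value vectors, but a transformer block then adds that output back onto [CLS]'s own input vector through a residual connection. Layer norm and the feed-forward sub-layer are omitted here. The toy below uses identity projections and made-up 2-d vectors. It shows that after the residual the result can fall outside the range any weighted average of the inputs could reach.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-head scaled dot-product attention (identity projections, toy)."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

# Hypothetical input: [CLS] plus two context tokens.
cls_embedding = [1.0, 0.0]
context = [[0.0, 1.0], [0.0, 1.0]]
inputs = [cls_embedding] + context

attn_out = attend(cls_embedding, inputs, inputs)   # a weighted average of the inputs
# Residual connection: the sub-layer output is ADDED to [CLS]'s own input vector.
residual_out = [x + a for x, a in zip(cls_embedding, attn_out)]

print("attention only:", attn_out)
print("with residual: ", residual_out)
```

The first coordinate of `residual_out` exceeds 1.0, the largest first coordinate among the inputs. So the block output for [CLS] is not a weighted average of the input vectors, even though the attention step inside it is.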
I had one question with respect to the CLS token.
It is said to represent the entire sequence.
Let’s say we have 2 input sentences for a fine-tuning task:
Hello, my name is ABC.
It’s a lovely weather today.
So, we have 2 sequences in this case.
As far as I understand, the CLS token has only one embedding vector, irrespective of the sequence, because CLS is just like a normal token: just as Hello, my, name, is, It, 's, and so on have their own individual embeddings, CLS has its own embedding. So how does one embedding vector (which does not seem to change across sequences) capture a sequence-level understanding of different sequences?
Pardon me for any gap in my understanding. Thank you.
Yes, the initial embedding of the [CLS] token is fixed: one vector shared by every input. But after the attention layers, the corresponding hidden vectors of the [CLS] token differ between input sequences. The final hidden vector of the [CLS] token is what we use for classification or downstream tasks, not its initial embedding.
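A toy version of the two example sentences makes this visible. Both sequences start from the identical [CLS] embedding, but after one layer in which [CLS] attends over the whole sequence (plus a residual connection), the two [CLS] vectors differ. The embedding table, the 2-d vectors and the single-head, identity-projection layer are all hypothetical simplifications. BERT-base stacks 12 layers with learned projections, multiple heads, layer norm and feed-forward sub-layers.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_layer_cls(vectors):
    """Toy layer: [CLS] (position 0) attends over all tokens, plus a residual."""
    query = vectors[0]
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in vectors]
    weights = softmax(scores)
    attn = [sum(w * v[j] for w, v in zip(weights, vectors)) for j in range(d)]
    return [x + a for x, a in zip(query, attn)]

# Hypothetical embedding table: one fixed vector per token, [CLS] included.
EMB = {
    "[CLS]": [0.5, 0.5], "[SEP]": [0.0, 0.0],
    "hello": [1.0, 0.0], "my": [0.0, 1.0], "name": [1.0, 1.0],
    "lovely": [2.0, -1.0], "weather": [-1.0, 2.0],
}

seq_a = ["[CLS]", "hello", "my", "name", "[SEP]"]
seq_b = ["[CLS]", "lovely", "weather", "[SEP]"]

in_a = [EMB[t] for t in seq_a]
in_b = [EMB[t] for t in seq_b]

out_a = attention_layer_cls(in_a)
out_b = attention_layer_cls(in_b)
print("initial [CLS]:", in_a[0], in_b[0])   # identical
print("after layer:  ", out_a, out_b)        # different
```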
Your understanding of the [CLS] token is correct. It serves as an aggregate representation of the entire sequence for classification tasks.
It’s positioned first in the sequence to ensure consistent handling of sequence-level tasks. While the [CLS] token learns in the same way as other tokens, it’s designed to capture the overall sequence information.
However, it cannot be used to reconstruct the original sequence, as it functions as a summary token.
I have a few follow-up questions:
Does it have to be the first token, or could it always be the last token? In that case its position embedding would change, but would the embedding still be learnt?
Does every other token in the sentence also attend to the CLS token embedding?