Word Embedding in Transformer




Post by mldev »

BERT is an encoder-only Transformer model that is often used as a feature extractor. There are two types of word embeddings in BERT.

Static Embeddings (aka Token Embeddings)

The static embeddings are obtained directly from the embedding layer of the BERT model. These embeddings are generally closer to one another in the vector space because they represent isolated words without any context. The static embeddings are learned during the pre-training stage and are fixed after the model is trained.
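As a rough sketch (assuming the HuggingFace transformers library and the bert-base-uncased checkpoint), a static embedding is just a row of the model's word-embedding matrix:

Code: Select all

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

def static_embedding(word):
    # Look up the word's vocabulary id and read its row from the
    # word-embedding matrix; no sentence context is involved.
    token_id = tokenizer.convert_tokens_to_ids(word)
    return model.embeddings.word_embeddings.weight[token_id].detach()

print(static_embedding('king').shape)  # torch.Size([768])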

Dynamic Embeddings (aka Contextualized Embeddings)

The dynamic embeddings are computed from the final layer of the BERT model, taking the entire sentence (context) into account. BERT modifies these embeddings based on the surrounding words, leading to significant variations depending on the context. This can cause large distances between embeddings when the context changes drastically.
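Continuing the sketch above (and assuming the target word stays a single token after tokenization, which holds for the words used here), a contextual embedding is the word's final-layer hidden state when the whole sentence is fed through the model:

Code: Select all

def contextual_embedding(word, sentence):
    # Encode the full sentence so the model can attend to the context.
    inputs = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    # Locate the word's token in the input and take its last-hidden-state
    # vector as the contextualized embedding.
    token_id = tokenizer.convert_tokens_to_ids(word)
    position = (inputs['input_ids'][0] == token_id).nonzero()[0].item()
    return outputs.last_hidden_state[0, position]

print(contextual_embedding('queen', 'The queen is kind.').shape)  # torch.Size([768])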

Consider the following example:

Code: Select all

# Words to analyze
words = ['king', 'queen', 'man', 'woman', 'apple']

# Context sentences
# format: word => [context 1 sentence, context 2 sentence]
contexts = {
    'king': ["The king is wise.", "The king and queen rule the kingdom."],
    'queen': ["The queen is kind.", "The queen is one of the great bands in history."],
    'man': ["The man is strong.", "The man and woman are friends."],
    'woman': ["The woman is smart.", "The woman and man are friends."],
    'apple': ["The king eats apple every day.", "How much is an Apple music account?"]
}

In this example, the same word can mean different things in contexts 1 and 2. Queen is the wife of the king in context 1, but in context 2 it refers to the music band. Apple is a fruit in context 1 and a company in context 2.
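With the helper functions sketched earlier, the similarity numbers behind the first heatmap can be reproduced along these lines (cosine similarity; the exact values depend on the checkpoint):

Code: Select all

import torch.nn.functional as F

def cos_sim(a, b):
    # Cosine similarity between two embedding vectors.
    return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

# Static similarity: no context at all
print(cos_sim(static_embedding('king'), static_embedding('queen')))

# Contextual similarity of king and queen in context 1 and context 2
for i in (0, 1):
    sim = cos_sim(contextual_embedding('king', contexts['king'][i]),
                  contextual_embedding('queen', contexts['queen'][i]))
    print(f'context {i + 1}: {sim:.2f}')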

We draw two heatmaps below, one for similarity and one for embedding distance.

[Attachment: bert_word_embedding_simi.png, similarity heatmap]

In the static word embedding graph, we can see that queen and king have a high similarity score of 0.65. In context 1 it is still high, at 0.72. However, in context 2, where queen means the band, its similarity with king drops to 0.45.

[Attachment: bert_word_embedding_diff.png, embedding distance heatmap]

In the static embedding distance graph above, we can see that the distance between queen and king is almost the same as the distance between woman and man. We can write it as:

Code: Select all

Static_Emb(queen) - Static_Emb(king) ~= Static_Emb(woman) - Static_Emb(man) (~=0.04)

This holds in context 1, where queen is the wife of the king. However, in context 2 the relation above no longer holds: the distance between man and woman is 7.4, but the distance between king and queen is 14. This is because queen now refers to the music band, so the model pushes its embedding away from its original position, much farther from king in the vector space.
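For reference, the distances can be checked with the same helpers (Euclidean distance here; the absolute numbers depend on the checkpoint and on the metric used for the heatmap):

Code: Select all

def emb_dist(a, b):
    # Euclidean (L2) distance between two embedding vectors.
    return torch.dist(a, b).item()

# Static distances: queen-king vs woman-man
print(emb_dist(static_embedding('queen'), static_embedding('king')))
print(emb_dist(static_embedding('woman'), static_embedding('man')))

# Contextual distances in context 2, where queen is the band
print(emb_dist(contextual_embedding('queen', contexts['queen'][1]),
               contextual_embedding('king', contexts['king'][1])))
print(emb_dist(contextual_embedding('woman', contexts['woman'][1]),
               contextual_embedding('man', contexts['man'][1])))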