From Memory Networks to Transformers

The attention mechanism is probably the most important idea in modern NLP. Even though it is almost omnipresent in new architectures, it is still poorly understood. That is not because attention is a complex concept, but because we, as NLP researchers, don't really know where it comes from and why it is so effective. In this article, I try to trace a meaningful historical background for attention and how it evolved into the modern form we all know today. I believe the best way to understand a concept is to understand the main motivations behind it, and that is only possible if we know its genealogy. The story starts, no surprise, with word vectors!

(the main ideas for this come from this amazing video)

Bag of words or vectors?!

The seminal paper of Mikolov et al. [1] on representing words as vectors in Euclidean space was so successful that it led to progress on many simple yet long-standing NLP tasks such as NER and POS tagging. Even though the idea was not totally new, the mindset it brought to the NLP community was unique: you can represent language items (at least words) as vectors that are ready to be manipulated by all sorts of mathematical tricks. Word2vec and all its successors had the amazing capability of representing words in a way that carries a lot of semantic information about them. The famous example king − man + woman ≈ queen shows this.

Here we don't want to spend more time explaining all the different word vectors; instead, we want to go back to the type of question researchers started to ask themselves:

If a word can be represented using a vector of a certain size, is it possible to do the same for a sentence or even a paragraph?

Let's try the most obvious yet simplistic approach: add up all the word vectors in the sentence. This treats the sentence exactly as a bag of words (or bag of vectors). Even though adding up the word vectors sounds naive or destructive, in practice it works surprisingly well! If we denote each word vector by v_i, with dimension d, the bag of vectors will be:

s = v_1 + v_2 + … + v_n = Σ_i v_i
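As a minimal sketch (word_vecs is assumed to be the (n, d) matrix of the sentence's word vectors, e.g. looked up from a pretrained word2vec model):

```python
import torch

def bag_of_vectors(word_vecs):
    """word_vecs: (n, d) tensor of a sentence's word vectors.
    The sentence representation is just their sum -- order is ignored."""
    return word_vecs.sum(dim=0)   # shape (d,)
```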

In many applications, like text mining or search engines, we really don't care about word order. But in more challenging tasks like translation or language comprehension, order is crucial.

RNN

Not all NLP tasks can be solved with the bag-of-words model, and this is why summation alone cannot be the best solution, especially when you can have an RNN. An RNN is a better solution because it models the "final" representation of a sentence not as a summation but as a more complex computation over the input words. Details aside, the whole idea is to update a hidden state based on the current word (token) and the previous hidden state.

More specifically, if we have an input sequence of word vectors x_0, x_1, …, x_n and the output sequence is the shifted sequence y_0 = x_1, y_1 = x_2, …,

then we will have a sequence h_0, h_1, …, h_n, where h_i represents the whole sequence up to the i'th token. The update rules (at training time) are:

h_i = σ(W x_i + U h_{i-1})
y_i = softmax(V h_i)

Here W, U, and V are trainable parameters and σ is a nonlinearity.
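A minimal sketch of this update rule as a vanilla RNN cell, with tanh as the nonlinearity (the sizes and the random parameters are only illustrative; a real model would learn them):

```python
import torch

d_in, d_h, d_out = 300, 128, 10000          # illustrative sizes
W = torch.randn(d_h, d_in) * 0.01           # input-to-hidden
U = torch.randn(d_h, d_h) * 0.01            # hidden-to-hidden
V = torch.randn(d_out, d_h) * 0.01          # hidden-to-output

def rnn_step(x_i, h_prev):
    h_i = torch.tanh(W @ x_i + U @ h_prev)  # new hidden state
    y_i = torch.softmax(V @ h_i, dim=0)     # prediction for the next token
    return h_i, y_i

h = torch.zeros(d_h)
for x_i in torch.randn(7, d_in):            # a toy sequence of 7 word vectors
    h, y = rnn_step(x_i, h)                 # h ends up summarizing the whole sequence
```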

This approach is much better than summation but still suffers from several problems. First of all, the step-by-step sequential updating causes the model to either forget the effect of the initial tokens or weight them more than needed. Also, its causal nature does not let the hidden state at step i see the future tokens.

Bahdanau and Luong came up with the idea of using attention. With the attention mechanism, we no longer use just one hidden state from the encoder in the decoder; instead, we compute a fresh context for every decoding step. We take the hidden state corresponding to the current decoding step (the query), compute its attention scores against the encoder hidden states (the keys), and finally take a weighted sum, with those attention scores, of the encoder hidden states again (the values); here keys = values. This produces c_t, which can be used alongside the current hidden state q_{t-1} to get the new hidden state q_t. q_t is then enough to sample from and produce real tokens.
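A hedged sketch of a single decoding step, using dot-product scoring (Luong style; Bahdanau's original scoring uses a small feed-forward network instead), with keys = values = the encoder hidden states:

```python
import torch
import torch.nn.functional as F

def attention_step(q_prev, enc_states):
    """q_prev: (d,) current decoder hidden state (the query).
    enc_states: (T, d) encoder hidden states (used as both keys and values)."""
    scores = enc_states @ q_prev        # (T,) similarity of the query with every key
    alpha = F.softmax(scores, dim=0)    # attention weights over the source tokens
    c_t = alpha @ enc_states            # (d,) context vector: weighted sum of the values
    return c_t, alpha
```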

In parallel, there was another trend in NLP that ignores the recurrent part and just uses attention.

Memory Networks

The attention mechanism was doing so well that some people started to ask whether we even need the recurrent part in RNN+Attention. This may sound like going backward, but dropping it is the first step to understanding the attention mechanism on its own, without any RNN hidden state involved.

Imagine we want to compute the hidden state corresponding to a sentence. We start with some q as the hidden state (we will discuss q shortly). If we don't care about RNNs, we simply compute a weighted sum (one more degree of complexity compared to vanilla summation), where the weights show how similar each input vector is to q. More specifically, these weights are just probabilities:

h = Σ_i p_i v_i

and the probabilities are calculated by:

p_i = softmax(q ⋅ v_i) = exp(q ⋅ v_i) / Σ_j exp(q ⋅ v_j)

This is sometimes called soft attention because we put a smooth weight on all the input vectors. Hard attention, in contrast, puts all the weight on one input and zero on everything else.
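The contrast is easy to see with toy tensors (dot-product similarity; v and q are made up purely for illustration):

```python
import torch
import torch.nn.functional as F

v = torch.randn(5, 300)             # 5 input word vectors of dimension 300
q = torch.randn(300)                # the query / hidden state

scores = v @ q                      # similarity of every input vector with q, shape (5,)

# soft attention: a smooth probability over all inputs
p_soft = F.softmax(scores, dim=0)
h_soft = p_soft @ v                 # weighted sum, shape (300,)

# hard attention: all the weight on the single best-matching input
p_hard = F.one_hot(scores.argmax(), num_classes=v.size(0)).float()
h_hard = p_hard @ v                 # simply picks one input vector
```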

In the case of seq2seq models, the initial hidden state comes from the decoder, and on the encoder side we calculate how much the input sequence attends to that hidden state. At the start, that hidden state corresponds to <s>, the beginning of the sentence we are decoding; as we move forward, the hidden state gets updated.

For now, forget about the decoder hidden state. Instead, think about another example in which the hidden state is a question (or query), and the vectors do not come from one sentence but from a collection of sentences; for each sentence we simply add up its word vectors to get a single vector representation. To make this more concrete, think about a question answering task: we have a set of sentences and a question, and we want to find the answer to that question based on the input sentences. For example:

In this example, we look for the answer to the question: “where is the milk?”

Each m_i is the simple summation of the word vectors of that sentence. In other words:

m_i = Σ_j x_{ij}, where x_{ij} is the j'th word vector of sentence i,

and the same for the question:

q = Σ_j x_{qj}, where x_{qj} is the j'th word vector of the question.

The attention mechanism at its core is just a multiplication, so if we want to update the hidden state to find the answer, we have to take more than one step. In the first step we find all the sentences that contain "milk", which are m4 and m2. Neither of them mentions a place. But after the first update (the weighted sum) we have a representation that contains both, and since they both mention "Joe", that word becomes reinforced. The next step then looks for the places where "Joe" has been. It is not always easy to explain what the network does, and for this reason we have to add more features to make it work. For example, time should be added to the m_i's, because the temporal order of the sentences is essential for answering, so:

m_i = Σ_j x_{ij} + t_i, where t_i is a learned temporal embedding (as in the original paper).

These steps are what we call "hops", so the algorithm will be:
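Roughly, with the symbols defined above (the q + o update follows the end-to-end memory networks paper; the number of hops is a hyperparameter):

```python
import torch
import torch.nn.functional as F

def answer_by_hops(q, memories, n_hops=2):
    """q: (d,) question vector; memories: (S, d) one summed vector per sentence."""
    for _ in range(n_hops):
        p = F.softmax(memories @ q, dim=0)   # how much each sentence attends to the current query
        o = p @ memories                     # weighted sum of the memories
        q = q + o                            # fold the retrieved information back into the query
    return q                                 # final state, used to score candidate answers
```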

This is what we know as Memory Networks, although according to one of its creators it is a very bad name and should have been called Attention Networks. Schematically, it looks like this:

In practice, and in a more realistic implementation, we need to add more complexity: we featurize the input vectors (before calculating attention) and the output vectors (after calculating attention) separately. We call them the in_memory and out_memory matrices, respectively. This will also help us understand Transformers at the end.

p = softmax(q M_A^T)
o = p M_B
q_new = q + o

M_A and M_B are S×d matrices, q is a 1×d vector, and p is a 1×S vector.

Here the φ functions consist of adding up the word vectors in each sentence followed by a linear transformation. In the original paper, they used A for the in_memory and C for the out_memory transformations.

In the following code snippet, you can see a very simple implementation of this idea in PyTorch (for the position-embedding details, see the paper). The whole code is here.
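A minimal single-hop sketch (bag-of-words sentence encoding; position and temporal embeddings omitted; class and argument names are illustrative rather than the exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MemoryNetwork(nn.Module):
    def __init__(self, vocab_size, d):
        super().__init__()
        self.A = nn.Embedding(vocab_size, d)   # in_memory transformation
        self.C = nn.Embedding(vocab_size, d)   # out_memory transformation
        self.B = nn.Embedding(vocab_size, d)   # question embedding
        self.W = nn.Linear(d, vocab_size)      # final answer prediction

    def forward(self, story, question):
        # story: (S, L) word ids of S sentences; question: (Lq,) word ids
        m = self.A(story).sum(dim=1)           # (S, d) in_memory vectors (phi_A: sum of word vectors)
        c = self.C(story).sum(dim=1)           # (S, d) out_memory vectors (phi_C)
        q = self.B(question).sum(dim=0)        # (d,)  question vector
        p = F.softmax(m @ q, dim=0)            # (S,)  attention over the sentences
        o = p @ c                              # (d,)  response vector read from out_memory
        return self.W(o + q)                   # logits over the vocabulary for the answer
```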

Transformer

Memory networks are a powerful framework for retrieval and question answering over a corpus of data. What made the next generation of attention networks much more powerful was a few additional features. The improvements are (in order of significance):

1- Multi hidden-state

2- Multi-head attention

3- Residual blocks

The first idea is crucial, because unlike memory networks, where you only have one hidden state that is manipulated throughout the process, here we have one hidden state per token. To understand what is happening, we have to reformulate the problem a little. We consider only one m (one memory) to see how each token changes the hidden states. In other words, for q_j, the input hidden state corresponding to token j of the question (or query):

p_j = softmax(q_j K^T)
q_j_new = p_j K

K is a T×d matrix, q_j and k_i are 1×d vectors, and p_j is a 1×T vector.

In this formulation the k_i and q_j are just word vectors (we use different names for clarity rather than calling them all r_i). We have to perform the computation not only for one j but for all j ∈ {1, …, T}, the length of the question (query).

From here, the derivation of the original Transformer is straightforward. Note that in the algorithm above we update q because, as in a memory network, we want to use it to answer the question; but we can also update the k_i, because we want to update the hidden representation of each token in the memory given the question (query), like this:

k_i_new = softmax(k_i Q^T) Q

Now, if we want to make this compact, as a single matrix multiplication, we just have to stack all the q_j's into one matrix Q:

Q_new = softmax(Q K^T) K

K and Q are both T×d matrices.

When Q = K we call this self-attention, because the query and the key are the same. We can reuse the featurization idea from memory networks and pass the queries and keys through linear projections; then K plays the role of in_memory, and we add a V that plays the role of out_memory in the next line. In other words, writing X for the T×d matrix of input token vectors, we have:

Q = X W^Q,  K = X W^K,  V = X W^V
Z = softmax(Q K^T) V

K, Q, and V are T×d matrices; W^K, W^V, and W^Q are all d×d matrices.

By now you have probably realized that the hops are just the layers of the original Transformer architecture. The core part can be summarized as:

Attention(Q, K, V) = softmax(Q K^T / √d) V

which is exactly what we have in the paper.
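The same formula as a small, unbatched sketch in code:

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """Q: (T_q, d), K: (T_k, d), V: (T_k, d) -> (T_q, d)."""
    d = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d ** 0.5   # (T_q, T_k) scaled dot products
    return F.softmax(scores, dim=-1) @ V          # weighted sum of the values
```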

The next idea is multi-head attention, which is basically several different linear projections to a lower dimension, whose attention outputs are then concatenated back up to the original dimension:

MultiHead(Q, K, V) = F(head_1, …, head_L) W^O,  where head_l = Attention(Q W^Q_l, K W^K_l, V W^V_l)

W^Q, W^K, and W^V are d×(d/L) matrices (they shrink the original d-dimensional vectors by a factor of L), and W^O is d×d (it leaves the dimension unchanged).

F is the concatenation function; for simplicity, we can also write this with the following notation:

MultiHead(Q, K, V) = (head_1 || head_2 || … || head_L) W^O

|| is the concatenation operator

And finally, the idea of residual connections is important: we want to carry the information from the original token embedding through all the transformations. This matters especially at the beginning of training and helps the model become stable faster.

The implementation of self-attention looks like the following:

In this diagram, the input (in light blue) is projected three ways to create Q, K, and V; this corresponds to multiplication by W^Q, W^K, and W^V respectively. Then each of Q, K, and V is projected to the lower, per-head dimension. In a real implementation this is done with a single three-dimensional tensor (the cubic shape), but for clarity we decompose it to see what happens. The multiplications happen, we compute the attention weights with a softmax, multiply by V, and finally concatenate all the parts to form the final transformed representation, which has exactly the same shape as the input. This closely matches the huggingface implementation.
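A condensed sketch of that computation for a single, unbatched sequence (a full implementation such as the huggingface one also handles batching, masking, and dropout; the class and variable names here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        assert d % n_heads == 0
        self.d_head = d // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d, d)   # the L per-head d x (d/L) projections, fused into one matrix
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_o = nn.Linear(d, d)   # output projection back to dimension d

    def forward(self, x):                      # x: (T, d) token representations
        T, d = x.shape
        def split(t):                          # (T, d) -> (n_heads, T, d_head)
            return t.view(T, self.n_heads, self.d_head).transpose(0, 1)
        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        scores = Q @ K.transpose(-2, -1) / self.d_head ** 0.5   # (n_heads, T, T)
        heads = F.softmax(scores, dim=-1) @ V                   # (n_heads, T, d_head)
        concat = heads.transpose(0, 1).reshape(T, d)            # concatenate the heads
        return self.W_o(concat)                                 # same shape as the input: (T, d)
```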

With these pieces in place, we can look at the original Transformer model this way:

When you think about the architecture in terms of queries, keys, and values, it does not look very different from RNN+Attention. The only difference is that in RNN+Attention the encoder and decoder are RNNs, whereas here even the encoder and decoder themselves are attention-based. Note that I use a different schema from the original paper for the decoder; in the paper, the decoder contains the attention part too.

The mathematics behind self-attention

(this section is written based on this blog post)

Associative memories are one of the earliest artificial neural models, dating back to the 1960s and 1970s. The best known are Hopfield networks, presented by John Hopfield in 1982. As the name suggests, the main purpose of an associative memory network is to associate an input with its most similar stored pattern; in other words, to store and retrieve patterns.

Every Hopfield network has N stored patterns X = (x_1, …, x_N), the keys, and we want to find the most relevant of those patterns for a given state pattern ξ (the query).

An energy function E is assigned to every Hopfield network. We take a state pattern (the query) as the starting point and follow an update rule that updates the state pattern; this is very similar to the query updates in the memory networks above. The updating process minimizes the energy E, and we stop updating when the energy is stable. In other words, the answer is a fixed point of the update rule, a local minimum of the energy.

In this setting, the query can be an incomplete part of one of the stored patterns (keys), for example retrieving the correct version of the following picture given only a part of it.

The energy function in modern Hopfield networks can be described as:

E = − Σ_{i=1..N} F(x_i^T ξ)

in which F is a polynomial function (e.g. F(z) = z^a); it has later been extended to the exponential function F(z) = exp(z).

Based on this energy function, and doing some math, we can derive the update rule:

ξ_new = X softmax(β X^T ξ)

where β is an inverse-temperature parameter.
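A tiny numeric sketch of this retrieval (the stored patterns and the corrupted query are random, purely for illustration):

```python
import torch
import torch.nn.functional as F

X = torch.randn(64, 10)                 # 10 stored patterns (keys), each 64-dimensional, as columns
xi = X[:, 3] + 0.3 * torch.randn(64)    # a corrupted / partial version of pattern 3 (the query)

beta = 1.0                              # inverse temperature; larger beta gives sharper retrieval
xi_new = X @ F.softmax(beta * (X.T @ xi), dim=0)   # one update step: xi_new is essentially pattern 3
```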

This is already very similar to the attention formula. With the following changes we arrive at it exactly:

  1. generalizing the new update rule to multiple patterns at once,
  2. mapping the patterns to an associative space,
  3. projecting the result

For S state patterns Ξ = (ξ_1, …, ξ_S):

Ξ_new = X softmax(β X^T Ξ)

If we take X^T to be the N raw stored patterns Y = (y_1, …, y_N)^T and map both the stored patterns and the state patterns into an associative space via learned matrices, we will have:

K = Y W_K = X^T,  Q = Ξ^T W_Q,  β = 1/√d

We then have:

Ξ_new = K^T softmax(β K Q^T)

Transposing both sides:

(Ξ_new)^T = softmax(β Q K^T) K

By projecting K with another matrix W_V we get the final equation:

Z = softmax(β Q K^T) K W_V = softmax(Q K^T / √d) V

In the case of an image, it would be like:

credit: https://ml-jku.github.io/hopfield-layers/

[1] Mikolov, T., Sutskever, I., Chen, K., Corrado, G.S., and Dean, J., 2013. Distributed representations of words and phrases and their compositionality. In Advances in neural information processing systems (pp. 3111–3119).
