
Hi everyone, if you’ve been following my previous blog posts, you’ll know that I’ve been using Transformers in my recent reinforcement learning (RL) agents (Playing Chess With Offline Reinforcement Learning, Playing Chess With a Generalized AI). Recently, a new type of Transformer called the Retrieval Transformer has been growing in popularity because of its efficiency. Furthermore, models like OpenAI’s WebGPT and DeepMind’s RETRO Transformer have shown that these smaller models can perform on par with large models such as GPT-3. Given these promising results, it seems logical to keep up with recent trends and try to incorporate this type of model into our RL agent.
Retrieval Transformers
Before incorporating the Retrieval Transformer in our RL agent, let’s wrap our heads around what a Retrieval Transformer is and why it’s so effective by looking into DeepMind’s RETRO Transformer. Below is a quote from a blog post by DeepMind where they describe the reasoning behind the methods used in the RETRO Transformer.
"Inspired by how the brain relies on dedicated memory mechanisms when learning, RETRO efficiently queries for passages of text to improve its predictions." [10]
This quote alludes to how our brain uses memories as a framework when making decisions by building on past experiences. For example, when you throw something up in the air, you expect it to hit the ground since your brain remembers the rule:
"what goes up must come down." – Isaac Newton [12]
DeepMind recreates this memory framework idea by retrieving examples from an external database that closely resemble the input. To determine what’s most similar to the input data, the RETRO Transformer splits the input into l chunks of m tokens each and finds the k-nearest neighbours (KNN) of the embeddings of l-1 chunks of the input data.

The neighbours and input data are encoded using separate self-attention encoder blocks and combined through a chunked cross-attention decoder block, giving the model an encoder-decoder architecture similar to the vanilla Transformer.
Chunked cross-attention is a form of cross-attention where we split the Query (Q), Key (K) and Value (V) into smaller chunks and then perform cross-attention on each of these chunks. However, we do not simply perform cross-attention between each original input chunk and its own neighbour chunk.

Instead, we shift the input data by m-1 tokens and create l-1 new input chunks. Each newly created chunk contains the last token of one original input chunk followed by the first m-1 tokens of the next original input chunk.

We then prepend the initially discarded m-1 tokens to the cross-attention outputs. By prepending the m-1 tokens, we retain more information from the original input data, which is critical since we want the model’s predictions to be influenced more strongly by the input data than by the KNN.
Finally, we perform cross-attention between the remaining token in our final chunk and the last neighbour.
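To make the index bookkeeping concrete, here is a small sketch of the shifting step on a toy sequence. The function name `shifted_chunks` and the example values are made up for illustration; it only shows which tokens end up in which chunk and is not the code used in the actual model.

```python
def shifted_chunks(tokens, m):
    """Split a sequence of l * m tokens into the pieces used by chunked cross-attention.

    Returns (head, chunks, tail):
      head   - the first m-1 tokens, set aside and prepended to the output later
      chunks - the l-1 shifted chunks of m tokens each
      tail   - the single remaining token of the final chunk
    """
    if len(tokens) % m != 0:
        raise ValueError("sequence length must be a multiple of the chunk size m")
    l = len(tokens) // m
    head = tokens[: m - 1]
    shifted = tokens[m - 1:]  # the input shifted by m-1 tokens
    chunks = [shifted[i * m:(i + 1) * m] for i in range(l - 1)]
    tail = shifted[(l - 1) * m:]
    return head, chunks, tail


# Toy example: 12 tokens, chunk size m = 4, so l = 3 original chunks.
head, chunks, tail = shifted_chunks(list(range(12)), m=4)
print(head)    # [0, 1, 2]
print(chunks)  # [[3, 4, 5, 6], [7, 8, 9, 10]]
print(tail)    # [11]
```

Notice that the shifted chunk [3, 4, 5, 6] starts with token 3, the last token of the original chunk [0, 1, 2, 3], followed by the first three tokens of the next original chunk.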

The output of the chunked cross-attention gets passed through a feed-forward layer.
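Putting the steps together, a minimal NumPy sketch of chunked cross-attention might look like the following. It makes several simplifying assumptions that are mine, not the RETRO implementation's: a single attention head, no learned Q/K/V projections, one block of encoded neighbour tokens per original chunk (shape `(l, r, d)`), and no feed-forward layer afterwards. The function names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(queries, memory):
    """Single-head scaled dot-product attention (no learned projections, for brevity)."""
    d = queries.shape[-1]
    scores = queries @ memory.T / np.sqrt(d)  # (num_queries, num_memory_tokens)
    return softmax(scores) @ memory           # (num_queries, d)


def chunked_cross_attention(x, neighbours, m):
    """x: (n, d) input encodings with n = l * m.
    neighbours: (l, r, d) encoded neighbour tokens, one block per original chunk (assumption).
    """
    n, d = x.shape
    l = n // m
    if n != l * m or neighbours.shape[0] != l:
        raise ValueError("expected n = l * m inputs and l neighbour blocks")
    head = x[: m - 1]    # first m-1 tokens, passed through unchanged
    shifted = x[m - 1:]  # input shifted by m-1 tokens
    outputs = [head]
    for u in range(l - 1):
        chunk = shifted[u * m:(u + 1) * m]  # shifted chunk u
        outputs.append(cross_attention(chunk, neighbours[u]))
    tail = shifted[(l - 1) * m:]            # the single remaining token
    outputs.append(cross_attention(tail, neighbours[l - 1]))
    return np.concatenate(outputs, axis=0)  # back to shape (n, d)


rng = np.random.default_rng(0)
x = rng.normal(size=(12, 8))              # l = 3 chunks of m = 4 tokens
neighbours = rng.normal(size=(3, 5, 8))   # 5 neighbour tokens per chunk
out = chunked_cross_attention(x, neighbours, m=4)
print(out.shape)  # (12, 8)
```

The Python loop over chunks keeps the sketch readable; a real layer would usually reshape the shifted input to `(l, m, d)` and attend over all chunks in one batched operation.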

Building Our Retrieval Transformer
Now that we understand the Retrieval Transformer better, we can effectively utilize its components in our model.
First, we’ll build the chunked cross-attention model. Since we’ve already created an attention model in a previous blog post, we can re-use it here. The main thing we need to add is a chunking mechanism that passes our neighbours and input into the attention model.
Now that we’ve built the chunked cross-attention model, we’ll create our KNN function. This function searches for neighbours in a potentially sizable sample space, so looping over every data point in Python (O(n) distance computations per query) is time-consuming. To optimize it, we’ll vectorize the function so that all distances are computed in a single matrix operation instead of an explicit loop. First, we’ll convert our dataframe to a tensor, which allows the use of matrix calculations. Matrix calculations are typically much faster than row-wise alternatives such as applying a function with pandas.

Below is the code for the matrix calculation version of our KNN function.
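As a rough illustration of the vectorized approach, here is a minimal sketch in NumPy. It is written under my own assumptions (Euclidean distance, embeddings already converted to arrays, a hypothetical function name `knn`) and is not necessarily identical to the function used in the project.

```python
import numpy as np


def knn(queries, database, k):
    """Return the indices and distances of the k nearest database rows for each query.

    queries:  (q, d) array of query embeddings
    database: (n, d) array of stored embeddings
    """
    # Squared Euclidean distances via ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2,
    # computed for every (query, database row) pair without a Python loop.
    q_sq = (queries ** 2).sum(axis=1, keepdims=True)   # (q, 1)
    d_sq = (database ** 2).sum(axis=1)                 # (n,)
    dists = q_sq - 2.0 * (queries @ database.T) + d_sq  # (q, n)
    dists = np.maximum(dists, 0.0)                     # guard against tiny negative values
    idx = np.argsort(dists, axis=1)[:, :k]             # k closest rows per query
    return idx, np.sqrt(np.take_along_axis(dists, idx, axis=1))


database = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
queries = np.array([[0.9, 0.1]])
idx, dist = knn(queries, database, k=2)
print(idx)  # [[1 0]]
```

Sorting every row is the simplest option; for a large database, `np.argpartition` can select the k smallest distances without a full sort.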
The model will look different with the newly created KNN function and chunked cross-attention layer. Below is a diagram of our new model architecture.

Looking at the diagram above, you’ll notice we added the new chunked cross-attention layer after our backbone layer. This placement was deliberate since our KNN function performs a lookup on a single matrix. Another benefit to placing our chunked cross-attention layer after the backbone layer is the amount of data represented in our backbone layer encodings. These encodings are the most data-rich in our model, allowing the found neighbours to supply a better framework for the final layers.
For our final step, we need to build the pipeline for the external embedding database.

The pipeline runs offline since we only need to recalculate the embeddings when the trunk of our model is updated. Below is the code for this pipeline.
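For readers who want a starting point, an offline pipeline of this kind could be sketched as follows. Everything here is a placeholder chosen for illustration: `encode` stands in for the trunk of the model (here just a fixed random projection), `positions` stands in for the stored game data, and the file name is arbitrary.

```python
import os
import tempfile

import numpy as np


def build_embedding_database(positions, encode, batch_size=256):
    """Encode every stored position in batches and stack the results into one matrix."""
    batches = []
    for start in range(0, len(positions), batch_size):
        batch = positions[start:start + batch_size]
        batches.append(encode(batch))
    return np.concatenate(batches, axis=0).astype(np.float32)


# Hypothetical stand-in for the model trunk: a fixed random projection.
rng = np.random.default_rng(0)
projection = rng.normal(size=(64, 16))


def encode(batch):
    return batch @ projection


positions = rng.normal(size=(1000, 64))  # placeholder for encoded board states
embeddings = build_embedding_database(positions, encode)

# Persist the database so the online model only has to load it.
# It needs to be rebuilt only when the trunk (here: `encode`) changes.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embeddings.npy")
    np.save(path, embeddings)
    reloaded = np.load(path)

print(reloaded.shape)  # (1000, 16)
```

In a real setup the embeddings would be written to a permanent location rather than a temporary directory, and the KNN lookup would read from that file at inference time.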
Thanks
And there you have it: we have successfully upgraded our chess AI to utilize components of the Retrieval Transformer. You can check a full version of the code on my GitHub here.
Thanks for reading. If you liked this, consider subscribing to my account to be notified of my most recent posts.
References
- https://arxiv.org/abs/1706.03762 [1]
- https://en.wikipedia.org/wiki/Reinforcement_learning [2]
- https://towardsdatascience.com/playing-chess-with-a-generalized-ai-b83d64ac71fe [3]
- https://towardsdatascience.com/building-a-chess-engine-part2-db4784e843d5 [4]
- https://openai.com/ [5]
- https://arxiv.org/abs/2112.09332 [6]
- https://deepmind.com/ [7]
- https://arxiv.org/abs/2112.04426 [8]
- https://arxiv.org/abs/2005.14165 [9]
- https://deepmind.com/blog/article/language-modelling-at-scale [10]
- https://bokcenter.harvard.edu/how-memory-works [11]
- https://www.goodreads.com/quotes/433926-what-goes-up-must-come-down [12]
- https://www.frieze.com/article/perception-vision#:~:text=In%20fact%2C%20it%20is%20now,brain%3B%20it%20comes%20from%20it.&text=In%20many%20ways%2C%20this%20makes%20sense. [13]






