{ "cells": [ { "cell_type": "markdown", "id": "73785fe3", "metadata": {}, "source": [ "This notebook is a code implementation of https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention." ] }, { "cell_type": "code", "execution_count": 12, "id": "2dc68e3f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.6.0+cu124'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "torch.__version__" ] }, { "cell_type": "markdown", "id": "13df2b7f", "metadata": {}, "source": [ "# Embedding an Input Sentence\n", "For simplicity, here our dictionary dc is restricted to the words that occur in the input sentence. In a real-world application, we would consider all words in the training dataset (typical vocabulary sizes range between 30k to 50k entries).\n", "\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "df7dbe06", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sentence = \"Life is short eat dessert first\"\n", "vocabulary = {s:i for i, s in enumerate(sorted(sentence.split()))}\n", "vocabulary" ] }, { "cell_type": "code", "execution_count": 14, "id": "6a2f6d1f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0, 4, 5, 2, 1, 3])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sentence_integer_encoding = torch.tensor([vocabulary[word] for word in sentence.split()])\n", "sentence_integer_encoding" ] }, { "cell_type": "markdown", "id": "669eb596", "metadata": {}, "source": [ "Now, using the integer-vector representation of the input sentence, we can use an embedding layer to encode the inputs into a real-vector embedding. Here, we will use a tiny 3-dimensional embedding such that each input word is represented by a 3-dimensional vector.\n", "\n", "Note that embedding sizes typically range from hundreds to thousands of dimensions. For instance, Llama 2 utilizes embedding sizes of 4,096. The reason we use 3-dimensional embeddings here is purely for illustration purposes. This allows us to examine the individual vectors without filling the entire page with numbers.\n", "\n", "Since the sentence consists of 6 words, this will result in a 6×3-dimensional embedding:" ] }, { "cell_type": "code", "execution_count": 15, "id": "9ca383b0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([6, 3]),\n", " tensor([[ 0.3374, -0.1778, -0.3035],\n", " [ 0.1794, 1.8951, 0.4954],\n", " [ 0.2692, -0.0770, -1.0205],\n", " [-0.2196, -0.3792, 0.7671],\n", " [-0.5880, 0.3486, 0.6603],\n", " [-1.1925, 0.6984, -1.4097]]))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab_size = 50000\n", "torch.manual_seed(123)\n", "embed = torch.nn.Embedding(vocab_size, 3)\n", "embedded_sentence = embed(sentence_integer_encoding).detach()\n", "embedded_sentence.shape, embedded_sentence" ] }, { "cell_type": "markdown", "id": "88aa6eee", "metadata": {}, "source": [ "# Defining the Weight Matrices\n", "Now, let's discuss the widely utilized self-attention mechanism known as the scaled dot-product attention, which is an integral part of the transformer architecture.\n", "\n", "Self-attention utilizes three weight matrices, referred to as Wq, Wk, and Wv, which are adjusted as model parameters during training. 
"\n", "The respective query, key, and value sequences are obtained via matrix multiplication between the weight matrices $\\mathbf{W}$ and the embedded inputs $\\mathbf{x}^{(i)}$:\n", "\n", "Query sequence: $\\mathbf{q}^{(i)} = \\mathbf{x}^{(i)} \\mathbf{W}_q$ for $i = 1, \\dots, T$\n", "\n", "Key sequence: $\\mathbf{k}^{(i)} = \\mathbf{x}^{(i)} \\mathbf{W}_k$ for $i = 1, \\dots, T$\n", "\n", "Value sequence: $\\mathbf{v}^{(i)} = \\mathbf{x}^{(i)} \\mathbf{W}_v$ for $i = 1, \\dots, T$\n", "\n", "Here, both $\\mathbf{q}^{(i)}$ and $\\mathbf{k}^{(i)}$ are vectors of dimension $d_k$. The projection matrices $\\mathbf{W}_q$ and $\\mathbf{W}_k$ have shape $d \\times d_k$, while $\\mathbf{W}_v$ has shape $d \\times d_v$.\n", "\n", "(It's important to note that $d$ represents the size of each word vector $\\mathbf{x}^{(i)}$; in this notebook, $d = 3$.)\n", "\n", "Since we are computing the dot product between the query and key vectors, these two vectors have to contain the same number of elements ($d_q = d_k$). In many LLMs, we use the same size for the value vectors, such that $d_q = d_k = d_v$. However, the number of elements in the value vector $\\mathbf{v}^{(i)}$, which determines the size of the resulting context vector, can be arbitrary.\n", "\n", "So, for the following code walkthrough, we will set $d_q = d_k = 2$ and use $d_v = 4$, initializing the projection matrices as follows:\n", "\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "312fdf10", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(123)\n", "\n", "d = embedded_sentence.shape[1]\n", "d_q, d_k, d_v = 2, 2, 4\n", "\n", "W_query = torch.nn.Parameter(torch.rand(d, d_q))\n", "W_key = torch.nn.Parameter(torch.rand(d, d_k))\n", "W_value = torch.nn.Parameter(torch.rand(d, d_v))" ] }, { "cell_type": "markdown", "id": "4b56007e", "metadata": {}, "source": [ "# Computing the Unnormalized Attention Weights\n", "Now, let's suppose we are interested in computing the attention vector for the second input element; in other words, the second input element acts as the query here:\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "dceaf3b4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2])\n", "torch.Size([2])\n", "torch.Size([4])\n" ] } ], "source": [ "x_2 = embedded_sentence[1]\n", "query_2 = x_2 @ W_query\n", "key_2 = x_2 @ W_key\n", "value_2 = x_2 @ W_value\n", "\n", "print(query_2.shape)\n", "print(key_2.shape)\n", "print(value_2.shape)" ] }, { "cell_type": "code", "execution_count": 18, "id": "5c322369", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([6, 2]) torch.Size([6, 4])\n" ] } ], "source": [ "# The query attends to every input element, so we need keys and values for all of them.\n", "keys = embedded_sentence @ W_key\n", "values = embedded_sentence @ W_value\n", "print(keys.shape, values.shape)" ] }, { "cell_type": "markdown", "id": "fb114cd7", "metadata": {}, "source": [ "All keys and values have now been computed, so we can move on to the unnormalized attention weights. The second input element still acts as the query. As a first step, we compute the unnormalized attention weight between this query and the fifth input element, i.e., how much attention element 2 should pay to element 5."
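, "\n", "\n", "In the notation introduced above, this unnormalized attention weight is the dot product between the corresponding query and key vectors,\n", "\n", "$\\omega_{2,4} = \\mathbf{q}^{(2)} \\cdot \\mathbf{k}^{(4)},$\n", "\n", "where the subscript 4 refers to the key's index position in the code below (`keys[4]`, i.e., the fifth input element)."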
] }, { "cell_type": "code", "execution_count": 20, "id": "9de7819a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.2903, grad_fn=)\n" ] } ], "source": [ "omega_24 = query_2.dot(keys[4])\n", "print(omega_24)" ] }, { "cell_type": "code", "execution_count": null, "id": "f4442320", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-0.6004, 3.4707, -1.5023, 0.4991, 1.2903, -1.3374],\n", " grad_fn=)\n" ] } ], "source": [ "omega_2 = query_2 @ keys.T\n", "# transpose keys to matrix multiply correctly\n", "# this is the logits.\n", "print(omega_2)" ] }, { "cell_type": "code", "execution_count": 24, "id": "7101391d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],\n", " grad_fn=)\n" ] } ], "source": [ "import torch.nn.functional as F\n", "# We scale by 1 / sqrt(d_k) to keep the weight vectors roughly the same magnitude so they\n", "# don't get too big or small. Then softmax the values to turn them into probs.\n", "attention_weights_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0)\n", "print(attention_weights_2)" ] }, { "cell_type": "code", "execution_count": null, "id": "d8163d6d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 5 }