Module 6 Lab: Attention patterns

In this lab you compute and visualize a scaled dot-product attention matrix for a short, four-token sequence.
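
For reference, the setup cell computes standard scaled dot-product attention. Each row of the softmax output is a probability distribution over the keys, so it sums to 1. In this lab the keys are two-dimensional, so $d_k = 2$ and every raw score is divided by $\sqrt{2} \approx 1.414$ before the softmax:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$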

Run the setup cell, inspect the printed attention weights and context vectors, and then complete the exercises at the end. The lab is intentionally small enough to run on a CPU, for example in GitHub Codespaces without a GPU.

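import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Nothing below is random; the seed only matters if you add random inputs in the exercises.
torch.manual_seed(16)

# Four 2-D token embeddings: tokens 0-1 point mostly along x, tokens 2-3 mostly along y.
tokens = torch.tensor([[1.0, 0.0], [0.8, 0.2], [0.0, 1.0], [0.1, 0.9]])

# Fixed (not learned) query and key projections; the values are the raw tokens.
Q = tokens @ torch.tensor([[1.0, 0.3], [0.2, 1.0]])
K = tokens @ torch.tensor([[1.0, -0.2], [0.1, 1.0]])
V = tokens

# Scaled dot-product attention: each row of `weights` is a softmax over the keys and sums to 1.
d_k = K.size(-1)
weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
context = weights @ V  # each context vector is a weighted average of the value rows

print("attention weights:")
print(weights.round(decimals=3))
print("context vectors:")
print(context.round(decimals=3))

# Heatmap of the attention matrix: rows are query positions, columns are key positions.
plt.figure(figsize=(3, 3))
plt.imshow(weights.detach(), vmin=0, vmax=1)
plt.xlabel("key position")
plt.ylabel("query position")
plt.colorbar()
plt.show()  # display the figure; plt.close() here would discard it before it is shown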
attention weights:
tensor([[0.3010, 0.2790, 0.2060, 0.2140],
        [0.2710, 0.2630, 0.2310, 0.2350],
        [0.1630, 0.1890, 0.3360, 0.3120],
        [0.1750, 0.1980, 0.3230, 0.3040]])
context vectors:
tensor([[0.5460, 0.4540],
        [0.5050, 0.4950],
        [0.3450, 0.6550],
        [0.3640, 0.6360]])
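
To see where these numbers come from, here is a minimal, dependency-free sketch that recomputes the same attention weights and context vectors with plain Python lists instead of PyTorch. The helper names `matmul`, `softmax`, and `scaled_dot_product_attention` are hypothetical, written only for this illustration; they are not PyTorch APIs. With the lab's toy inputs, the printed rows should agree with the tensors above to three decimals.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(Q, K, V):
    """Return (weights, context) for softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]  # transpose K
    scores = matmul(Q, K_T)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return weights, matmul(weights, V)


# Same toy inputs as the lab cell.
tokens = [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0], [0.1, 0.9]]
W_q = [[1.0, 0.3], [0.2, 1.0]]
W_k = [[1.0, -0.2], [0.1, 1.0]]
Q = matmul(tokens, W_q)
K = matmul(tokens, W_k)
V = tokens

weights, context = scaled_dot_product_attention(Q, K, V)
for w_row, c_row in zip(weights, context):
    print([round(w, 3) for w in w_row], [round(c, 3) for c in c_row])
```

For example, for query 0 the raw score against key 0 is $1 \cdot 1 + 0.3 \cdot (-0.2) = 0.94$, which becomes about $0.665$ after dividing by $\sqrt{2}$; the softmax then turns the four scaled scores in that row into the weights printed above.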

Lab exercises

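  1. Change one parameter, such as a row of `tokens` or an entry of the query or key projection matrix, and rerun the lab.

  2. Record whether the attention weights became sharper (more weight concentrated on one key), flatter, or stayed roughly the same.

  3. Add one sentence connecting the result to the broader topic of attention and transformers.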

  4. Identify one limitation of this toy setup before applying the idea to a real dataset.
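
Exercise 2 asks you to record a change, but the lab cell itself prints no single summary number. One possible choice, an assumption rather than part of the original lab, is the Shannon entropy of each attention row: lower entropy means sharper attention concentrated on fewer keys, and the maximum, $\ln 4 \approx 1.386$ for four keys, means perfectly uniform attention. The helpers `row_entropy` and `mean_entropy` are hypothetical names for this sketch.

```python
import math


def row_entropy(row):
    """Shannon entropy (in nats) of one attention row; 0 means all weight on a single key."""
    return -sum(p * math.log(p) for p in row if p > 0)


def mean_entropy(weights):
    """Average entropy over query rows; ln(number of keys) is the uniform-attention maximum."""
    return sum(row_entropy(r) for r in weights) / len(weights)


# Rounded rows copied from the printed attention weights of the baseline run.
baseline = [
    [0.301, 0.279, 0.206, 0.214],
    [0.271, 0.263, 0.231, 0.235],
    [0.163, 0.189, 0.336, 0.312],
    [0.175, 0.198, 0.323, 0.304],
]
print(f"mean entropy: {mean_entropy(baseline):.3f} (uniform max = {math.log(4):.3f})")
```

In the notebook you could pass `weights.tolist()` instead of the copied rows, then compare the baseline value with the value after your change: a lower mean entropy indicates sharper attention.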

# Reflection workspace
observation = ""
next_experiment = ""
print({"observation": observation, "next_experiment": next_experiment})
{'observation': '', 'next_experiment': ''}