A Transformer Block in CUDA
December, somewhere over the Pacific on the way to Tokyo. The LetGPU challenge open in one browser tab, Ro Salaverry’s The Scaling Era: An Oral History of AI, 2019–2025 on my Kindle. No plan beyond getting a feel for the challenge: I poked at it for a while, read a few chapters, and fell asleep somewhere east of the dateline. Woke up on approach to Narita.
The book stuck with me, though. It is a collection of firsthand accounts from the people who built the systems that now run most of the software industry, and the recurring theme is the same across a dozen voices: the transformer architecture arrived, and then everything else followed from scale.
That is what pulled me back to the challenge once I had a hotel and a power outlet. Not because I do this for a living — machine learning is not my area — but because the underlying question is squarely in territory I care about: what does this look like in hardware, and why does it work the way it does?
A previous post covered the GPT-2 decoder block at a conceptual level: what the LetGPU challenge is asking for, how data flows through the block, what each parameter group does, and why the structure is what it is. This post assumes that context — either you have read it, or you already know GPT-2 well enough that you do not need it. Here we go deeper: the actual CUDA implementation, kernel by kernel, with the concrete hardware trade-offs that make each design decision interesting. The code implements a single transformer block at GPT-2 small scale. No PyTorch, no cuBLAS — just kernels.
The Block, Concretely
The implementation is GPT-2 small: hidden dimension D=768, twelve attention
heads (H=12), head dimension Dh=64, feed-forward dimension FF=3072. One
forward pass through a block takes a token sequence of shape (seq_len, 768)
and produces the same shape.
This is a Pre-LN architecture: layer norm is applied to the input of each sublayer, not after it. The original Transformer paper used Post-LN (norm after the residual add). The difference matters. Post-LN renormalizes the residual stream after every sublayer, while Pre-LN leaves the residual path as a plain identity from block input to block output, so gradients reach early layers through it undisturbed. That makes deeper stacks easier to train. GPT-2 and most subsequent work moved to Pre-LN; the original formulation is now mostly a historical artifact.
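To make the residual-path difference concrete, here is a small host-side sketch in plain C++ (it also builds under nvcc). The names `pre_ln_step` and `post_ln_step` are hypothetical, the layer norm leaves out gamma and beta, and the sublayer is a stand-in for attention or the FFN. With a sublayer that outputs zeros, Pre-LN passes x through unchanged while Post-LN still renormalizes it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<float>;
using Sublayer = std::function<Vec(const Vec&)>;

// Layer norm without the learned gamma/beta; enough for this comparison.
Vec layer_norm(const Vec& x, float eps = 1e-5f) {
    float mean = 0.0f;
    for (float v : x) mean += v;
    mean /= static_cast<float>(x.size());
    float var = 0.0f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= static_cast<float>(x.size());
    const float inv_std = 1.0f / std::sqrt(var + eps);
    Vec y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) y[i] = (x[i] - mean) * inv_std;
    return y;
}

// Pre-LN (GPT-2): out = x + f(LN(x)). The residual path itself is never normalized.
Vec pre_ln_step(const Vec& x, const Sublayer& f) {
    Vec out = f(layer_norm(x));
    for (std::size_t i = 0; i < x.size(); ++i) out[i] += x[i];
    return out;
}

// Post-LN (original Transformer): out = LN(x + f(x)). The residual sum is renormalized.
Vec post_ln_step(const Vec& x, const Sublayer& f) {
    Vec sum = f(x);
    for (std::size_t i = 0; i < x.size(); ++i) sum[i] += x[i];
    return layer_norm(sum);
}
```

With `x = {1, 2, 3, 4}` and a zero sublayer, `pre_ln_step` returns x exactly, while `post_ln_step` returns roughly `{-1.342, -0.447, 0.447, 1.342}`.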
The ten steps of the forward pass:
x ──► LN1 ──► QKV proj ──► MHA ──► O proj ──► add(x, ·) ──► x1
x1 ──► LN2 ──► FC proj ──► GELU ──► out proj ──► add(x1, ·) ──► output
In kernel terms:
1. layernorm768_kernel(x) → ln1
2. matmul_bias_tiled(ln1, W_qkv) → qkv, shape (seq_len, 2304), packing Q, K, V
3. mha_nocausal_kernel(qkv) → attn, shape (seq_len, 768)
4. matmul_bias_tiled(attn, W_attn) → proj
5. add_kernel(x, proj) → x1, the first residual
6. layernorm768_kernel(x1) → ln2
7. matmul_bias_tiled(ln2, W_fc) → ff1, shape (seq_len, 3072)
8. gelu_kernel(ff1), in place
9. matmul_bias_tiled(ff1, W_proj) → ff2, shape (seq_len, 768)
10. add_kernel(x1, ff2) → output, the second residual
The rest of this post walks each step in turn.
Block Reduction
Two of the kernels need a single value derived from all threads in a block: layer norm needs sums (for the mean and the variance), and attention needs a maximum (for softmax numerical stability) followed by a sum. The standard tool is a tree reduction over shared memory:
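__device__ __forceinline__ float block_reduce_sum(float v) {
__shared__ float smem[1024];   // assumes blockDim.x is a power of two <= 1024
int tid = threadIdx.x;
smem[tid] = v;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) smem[tid] += smem[tid + stride];
__syncthreads();
}
float total = smem[0];
__syncthreads();   // every thread has read smem[0] before a later call reuses smem
return total;
}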
Each thread writes its local value to shared memory, then the loop halves the
number of active threads on each iteration. With the 256-thread blocks used
here, threads 0–127 add in the values from threads 128–255, then threads 0–63
add from 64–127, and so on. After log₂(blockDim.x) iterations, smem[0] holds
the total, and every thread in the block reads that result. The halving only
works cleanly when blockDim.x is a power of two.
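The cost is the __syncthreads() calls, each a block-wide barrier that stalls
until every thread in the block reaches it. With 256 threads that is
log₂(256) = eight loop iterations, each ending in a barrier, plus the barrier
after the initial store, and layer norm alone runs two reductions. Within this
shared-memory scheme the barriers cannot be dropped: without them, threads
would read partial sums their neighbors have not written yet. The same hazard
exists between two back-to-back reductions that share the buffer: unless a
barrier separates reading smem[0] from the next store into smem, the second
call can overwrite the total before every thread has read it. (Warp-level
shuffles can remove most of these barriers; this implementation does not use
them.)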
block_reduce_max is identical with fmaxf in place of addition.
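A host-side simulation makes the stride pattern and the iteration count easy to check. In this plain C++ sketch (`tree_reduce_sum` and `ReduceResult` are hypothetical names), each pass of the outer loop stands for one stride iteration followed by a __syncthreads(). Because "threads" only write below `stride` and only read at or above it, running them sequentially gives the same result as running them in parallel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct ReduceResult {
    float total;
    int iterations;  // stride iterations, i.e. barriers inside the loop
};

// Simulates block_reduce_sum's tree over one value per thread.
// Assumes smem.size() (the "block size") is a power of two, such as 256.
ReduceResult tree_reduce_sum(std::vector<float> smem) {
    int iterations = 0;
    for (std::size_t stride = smem.size() / 2; stride > 0; stride >>= 1) {
        // Every "thread" tid < stride adds its partner's partial sum.
        for (std::size_t tid = 0; tid < stride; ++tid) smem[tid] += smem[tid + stride];
        ++iterations;  // the __syncthreads() at the end of this iteration
    }
    return {smem[0], iterations};
}
```

Feeding it the values 1 through 256 gives a total of 32896 after eight iterations, matching log₂(256).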
Layer Normalization
One block per token. All threads in the block cooperate on a single 768-element normalization:
// Pass 1: mean
float sum = 0.0f;
for (int i = threadIdx.x; i < 768; i += blockDim.x) sum += x[i];
sum = block_reduce_sum(sum);
float mean = sum / 768.0f;
// Pass 2: variance, then normalize
float vsum = 0.0f;
for (int i = threadIdx.x; i < 768; i += blockDim.x) {
float d = x[i] - mean; vsum += d * d;
}
vsum = block_reduce_sum(vsum);
float inv_std = rsqrtf(vsum / 768.0f + 1e-5f);
for (int i = threadIdx.x; i < 768; i += blockDim.x)
y[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
rsqrtf is the reciprocal square root — a single hardware instruction on any
modern GPU, cheaper than computing sqrtf and then dividing. The epsilon
1e-5f keeps the denominator away from zero. gamma and beta are the
learned affine parameters applied after normalization.
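For testing the kernel it helps to have a CPU reference for one token. This is a plain C++ sketch (`layernorm_ref` is a hypothetical name) that mirrors the kernel's two reduction passes and epsilon. It uses `1.0f / std::sqrt` where the kernel uses rsqrtf, so the two should agree to within rounding rather than bit for bit.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for one row of layernorm768_kernel: mean pass, variance pass,
// then normalize and apply the learned gamma/beta.
std::vector<float> layernorm_ref(const std::vector<float>& x,
                                 const std::vector<float>& gamma,
                                 const std::vector<float>& beta) {
    const std::size_t n = x.size();
    float sum = 0.0f;
    for (float v : x) sum += v;
    const float mean = sum / static_cast<float>(n);
    float vsum = 0.0f;
    for (float v : x) vsum += (v - mean) * (v - mean);
    const float inv_std = 1.0f / std::sqrt(vsum / static_cast<float>(n) + 1e-5f);
    std::vector<float> y(n);
    for (std::size_t i = 0; i < n; ++i)
        y[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
    return y;
}
```

With gamma set to 1 and beta to 0, the output row should have mean close to 0 and variance close to 1, which is a quick sanity check for the kernel's output as well.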
Tiled Matrix Multiply
The projection steps (QKV, output, FFN up and down) all go through one templated kernel. The core idea is tiling: the block loads a 16×16 submatrix of A and a 16×16 submatrix of B into shared memory before computing their product. Each of the 256 threads fetches one element of each tile from slow global memory, and every staged element is then read sixteen times from fast shared memory, a sixteen-fold reduction in global loads compared with each thread reading its own row and column directly.
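template<int TILE>
__global__ void matmul_bias_tiled(const float* __restrict__ A,
const float* __restrict__ B,
const float* __restrict__ bias,
float* __restrict__ C,
int M, int K, int N) {
int row = blockIdx.y * TILE + threadIdx.y;   // output row this thread owns
int col = blockIdx.x * TILE + threadIdx.x;   // output column this thread owns
__shared__ float As[TILE][TILE];
__shared__ float Bs[TILE][TILE];
float acc = 0.0f;
for (int k0 = 0; k0 < K; k0 += TILE) {
// each thread stages one element of each tile, zero-padded past the edges
As[threadIdx.y][threadIdx.x] = (row < M && k0 + threadIdx.x < K) ? A[row * K + k0 + threadIdx.x] : 0.0f;
Bs[threadIdx.y][threadIdx.x] = (col < N && k0 + threadIdx.y < K) ? B[(k0 + threadIdx.y) * N + col] : 0.0f;
__syncthreads();   // tiles fully loaded before anyone reads them
#pragma unroll
for (int k = 0; k < TILE; ++k)
acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
__syncthreads();   // everyone done reading before the next overwrite
}
if (row < M && col < N) C[row * N + col] = acc + bias[col];
}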
The thread at (threadIdx.y, threadIdx.x) is responsible for one output
element at (row, col). It accumulates the dot product of a row of A with a
column of B, advancing through K in steps of TILE=16. The two
__syncthreads() per iteration guard the shared tiles in both directions: the
first ensures all threads have finished loading into shared memory before any
thread starts reading from it; the second ensures every thread has finished
reading before the next iteration overwrites the tiles.
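#pragma unroll on the inner loop asks the compiler to unroll the 16
iterations at compile time, producing 16 back-to-back FMAs with no loop counter
or branch between them. They are not independent: each one accumulates into
acc, so they form a single dependency chain. The gain is the removed loop
overhead, not extra parallelism.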
__restrict__ on the pointer arguments is a promise to the compiler that A, B,
and C do not alias — the compiler can assume no write to C affects what is read
from A or B, enabling more aggressive load scheduling.
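The tiling is easier to check on the host. The plain C++ sketch below (`matmul_bias_tiled_ref` is a hypothetical name) keeps the kernel's structure: the (ti, tj) loops stand in for the block grid, the k0 loop walks K one TILE-wide slab at a time, and each staged element of As and Bs is used TILE times. It zero-pads past the matrix edges the same way the kernel's bounds checks do, so shapes that are not multiples of 16 work too.

```cpp
#include <cassert>
#include <vector>

// C = A(M,K) * B(K,N) + bias(N), row-major, computed tile by tile.
template <int TILE>
void matmul_bias_tiled_ref(const std::vector<float>& A, const std::vector<float>& B,
                           const std::vector<float>& bias, std::vector<float>& C,
                           int M, int K, int N) {
    for (int ti = 0; ti < M; ti += TILE)          // blockIdx.y
        for (int tj = 0; tj < N; tj += TILE) {    // blockIdx.x
            float acc[TILE][TILE] = {};           // one accumulator per "thread"
            for (int k0 = 0; k0 < K; k0 += TILE) {
                // Stage the "shared memory" tiles, zero-padded past the edges.
                float As[TILE][TILE], Bs[TILE][TILE];
                for (int y = 0; y < TILE; ++y)
                    for (int x = 0; x < TILE; ++x) {
                        int r = ti + y, c = tj + x;
                        As[y][x] = (r < M && k0 + x < K) ? A[r * K + k0 + x] : 0.0f;
                        Bs[y][x] = (c < N && k0 + y < K) ? B[(k0 + y) * N + c] : 0.0f;
                    }
                // Every staged element is reused TILE times here.
                for (int y = 0; y < TILE; ++y)
                    for (int x = 0; x < TILE; ++x)
                        for (int k = 0; k < TILE; ++k)
                            acc[y][x] += As[y][k] * Bs[k][x];
            }
            for (int y = 0; y < TILE; ++y)
                for (int x = 0; x < TILE; ++x)
                    if (ti + y < M && tj + x < N)
                        C[(ti + y) * N + tj + x] = acc[y][x] + bias[tj + x];
        }
}
```

Comparing it against a naive triple loop on a 20×40 by 40×36 product exercises the partial tiles at the edges; with small integer inputs the two results match exactly.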
Multi-Head Attention
This is the interesting one. One CUDA block per (token, head) pair — the
launch grid is (seq_len, 12). Each block computes one row of the attention
output for one head.
The attention score for query token t against key token s in head h is:
score(t, s) = dot(Q[t,h], K[s,h]) / sqrt(64)
Then a softmax over all s turns the scores into weights p(t, s), and the output is the weighted sum of the value vectors: out(t, h) = Σ_s p(t, s) · V[s,h].
Softmax has a numerical problem: if scores are large, exp(score) overflows.
The standard fix is to subtract the maximum before exponentiating. That
requires knowing the maximum, which requires a pass over all scores. So the
kernel runs three passes:
// Pass 1: compute raw scores, find max
const float* Q = qkv + t * (3 * D) + 0 * D + h * Dh;
float local_max = -1e30f;
for (int s = threadIdx.x; s < seq_len; s += blockDim.x) {
const float* K = qkv + s * (3 * D) + 1 * D + h * Dh;
float dot = 0.0f;
for (int i = 0; i < Dh; ++i) dot += Q[i] * K[i];
scores[s] = dot * inv_sqrt_dh;
local_max = fmaxf(local_max, scores[s]);
}
float max_sc = block_reduce_max(local_max);
// Pass 2: exp(score - max), accumulate sum
float local_sum = 0.0f;
for (int s = threadIdx.x; s < seq_len; s += blockDim.x) {
float e = expf(scores[s] - max_sc);
scores[s] = e; // reuse the shared buffer
local_sum += e;
}
float inv_sum = 1.0f / block_reduce_sum(local_sum);
// Pass 3: weighted sum of V (first 64 threads only)
if (threadIdx.x < Dh) {
float acc = 0.0f;
for (int s = 0; s < seq_len; ++s) {
const float* V = qkv + s * (3 * D) + 2 * D + h * Dh;
acc += scores[s] * inv_sum * V[threadIdx.x];
}
attn_out[t * D + h * Dh + threadIdx.x] = acc;
}
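To see why the max subtraction matters, here is a host-side C++ sketch comparing a naive softmax with the three-pass form on scores near 1000. The function names are hypothetical. The kernel folds the final normalization into the V-weighted sum instead of writing normalized probabilities, so this is the same arithmetic rearranged.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive softmax: exp(score) directly. In float32, exp overflows to inf once a
// score exceeds about 88, and inf / inf gives NaN.
std::vector<float> softmax_naive(const std::vector<float>& s) {
    std::vector<float> p(s.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < s.size(); ++i) { p[i] = std::exp(s[i]); sum += p[i]; }
    for (float& v : p) v /= sum;
    return p;
}

// Max-subtracted softmax, following the kernel's passes.
std::vector<float> softmax_stable(const std::vector<float>& s) {
    const float mx = *std::max_element(s.begin(), s.end());  // pass 1: max
    std::vector<float> p(s.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < s.size(); ++i) {              // pass 2: exp and sum
        p[i] = std::exp(s[i] - mx);
        sum += p[i];
    }
    const float inv_sum = 1.0f / sum;
    for (float& v : p) v *= inv_sum;  // the kernel applies this inside pass 3 instead
    return p;
}
```

For the scores {1000, 999, 998}, the naive version returns NaN, while the stable version returns about {0.665, 0.245, 0.090}.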
The scores buffer lives in dynamic shared memory (extern __shared__), sized
at launch time as seq_len * sizeof(float). The size is not known at compile
time, so it is passed as the third argument of the <<<>>> launch
configuration: mha_nocausal_kernel<<<grid, block, seq_len * sizeof(float)>>>.
Shared memory per block is limited: 48 KB by default, and more only with an
explicit opt-in through cudaFuncSetAttribute, up to a ceiling that depends on
the architecture. This kernel also carries 8 KB of static shared memory from the
two 1024-float buffers in block_reduce_max and block_reduce_sum, which counts
against the same budget. The implementation as written therefore suits GPT-2
small context lengths (up to 1024 tokens, or 4 KB of scores), not arbitrarily
long sequences.
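The budget arithmetic can be written down directly. This sketch assumes 4-byte floats, the 48 KB default per-block limit with no opt-in, and that the two 4 KB reduction buffers count against the same budget as the dynamic scores buffer. The names `kStaticSmemBytes`, `scores_bytes` and `max_seq_len` are illustrative.

```cpp
#include <cassert>
#include <cstddef>

// Static shared memory used by mha_nocausal_kernel: the 1024-float buffers
// inside block_reduce_max and block_reduce_sum.
constexpr std::size_t kStaticSmemBytes = 2 * 1024 * sizeof(float);  // 8 KB

// Dynamic bytes requested at launch for the scores buffer.
constexpr std::size_t scores_bytes(int seq_len) {
    return static_cast<std::size_t>(seq_len) * sizeof(float);
}

// Longest sequence whose scores buffer still fits next to the static buffers
// within a per-block budget given in bytes.
constexpr int max_seq_len(std::size_t budget_bytes) {
    return static_cast<int>((budget_bytes - kStaticSmemBytes) / sizeof(float));
}
```

Under those assumptions, `scores_bytes(1024)` is 4096 and `max_seq_len(48 * 1024)` is 10240, so the kernel has plenty of headroom for GPT-2's 1024-token context.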
The QKV layout is packed by position: each token's 2304-float row holds its Q,
K, and V contiguously, 768 floats each, and within each section head h occupies
the 64 floats starting at h * 64. The pointer for Q at token t, head h is
qkv + t * 2304 + 0 * 768 + h * 64; K and V use section offsets 1 * 768
and 2 * 768.
Pass 3 only uses the first 64 threads (one per head dimension). The remaining threads in the block are idle for that phase. This is a deliberate simplicity trade-off — for a challenge implementation it is fine; production kernels would use the idle threads to pipeline across heads or fuse operations.
GELU
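The feed-forward sublayer uses GELU as its activation function. The exact
definition is GELU(x) = x · Φ(x) = 0.5 · x · (1 + erf(x / sqrt(2))), where Φ is
the standard normal CDF. The original GPT-2 code uses a tanh approximation
instead, and the kernel matches it: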
static __device__ __forceinline__ float gelu_tanh(float x) {
const float k = 0.7978845608028654f; // sqrt(2/pi)
float x3 = x * x * x;
return 0.5f * x * (1.0f + tanhf(k * (x + 0.044715f * x3)));
}
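The tanh term approximates erf(x/sqrt(2)); the cubic coefficient 0.044715 comes
from a numerical fit. It is not a Padé approximant, which would be a ratio of
polynomials. The maximum absolute error relative to the exact GELU is small
enough to be irrelevant for inference (on the order of 10⁻⁴ in the active range;
the exact bound depends on the measurement interval). PyTorch exposes the same
formula as nn.GELU(approximate='tanh'), although its default is the exact erf form.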
The kernel applies it elementwise across the (seq_len, 3072) intermediate
tensor. Apart from the tanhf call the expression is a handful of multiplies and
adds, and each element is read once and written once, so the kernel is most
likely limited by memory bandwidth rather than arithmetic.
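The size of the approximation error is easy to measure on the host. This C++ sketch (hypothetical helper names) compares the kernel's tanh formula against the exact erf form using std::erf from <cmath>, scanning a uniform grid. The worst case it finds over [-6, 6] should come out below 10⁻³.

```cpp
#include <cassert>
#include <cmath>

// Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
float gelu_erf(float x) {
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678118654752f));
}

// The tanh approximation used by gelu_tanh in the kernel.
float gelu_tanh(float x) {
    const float k = 0.7978845608028654f;  // sqrt(2/pi)
    float x3 = x * x * x;
    return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x3)));
}

// Largest |gelu_tanh(x) - gelu_erf(x)| on a uniform grid of steps + 1 points in [lo, hi].
float max_gelu_error(float lo, float hi, int steps) {
    float worst = 0.0f;
    for (int i = 0; i <= steps; ++i) {
        float x = lo + (hi - lo) * i / steps;
        worst = std::fmax(worst, std::fabs(gelu_tanh(x) - gelu_erf(x)));
    }
    return worst;
}
```

The largest gaps show up for moderate |x| (around 2 to 3), where the tanh fit and erf diverge most before both saturate.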
Putting It Together
The solve function allocates eight intermediate tensors on the device,
launches the kernels in sequence, and frees them. The weight layout is a flat
device buffer. Offsets are computed from the architecture constants:
| Weights | Offset (floats) | Size (floats) |
|---|---|---|
| γ₁, β₁ | 0 | 768 + 768 |
| W_qkv | 1,536 | 768 × 2,304 |
| b_qkv | 1,771,008 | 2,304 |
| W_attn | 1,773,312 | 768 × 768 |
| b_attn | 2,363,136 | 768 |
| γ₂, β₂ | 2,363,904 | 768 + 768 |
| W_fc | 2,365,440 | 768 × 3,072 |
| b_fc | 4,724,736 | 3,072 |
| W_proj | 4,727,808 | 3,072 × 768 |
| b_proj | 7,087,104 | 768 |
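The offsets in the table are cumulative sums of the sizes, so they can be derived rather than hard-coded. Here is a sketch using C++ constexpr (the constant names are illustrative, not from the challenge). It can be checked against the table and against the literal offsets in solve.

```cpp
#include <cassert>
#include <cstddef>

// GPT-2 small architecture constants.
constexpr std::size_t D = 768, FF = 3072;

// Offsets in floats (not bytes): each entry starts where the previous one ends.
constexpr std::size_t GAMMA1 = 0;
constexpr std::size_t BETA1  = GAMMA1 + D;
constexpr std::size_t W_QKV  = BETA1 + D;
constexpr std::size_t B_QKV  = W_QKV + D * 3 * D;
constexpr std::size_t W_ATTN = B_QKV + 3 * D;
constexpr std::size_t B_ATTN = W_ATTN + D * D;
constexpr std::size_t GAMMA2 = B_ATTN + D;
constexpr std::size_t BETA2  = GAMMA2 + D;
constexpr std::size_t W_FC   = BETA2 + D;
constexpr std::size_t B_FC   = W_FC + D * FF;
constexpr std::size_t W_PROJ = B_FC + FF;
constexpr std::size_t B_PROJ = W_PROJ + FF * D;
constexpr std::size_t BLOCK_PARAMS = B_PROJ + D;  // total floats per block
```

BLOCK_PARAMS comes to 7,087,872 floats, and twelve blocks give 85,054,464, embeddings not included.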
About 7.1 million parameters per block (7,087,872 exactly). GPT-2 small stacks twelve of them, about 85 million parameters in the blocks alone; the token and position embeddings bring the full model to roughly 124 million. The numbers in The Scaling Era that stuck with me were the orders of magnitude larger: GPT-3 at 175 billion, models that followed reportedly in the trillions. The architecture is the same. The arithmetic just runs longer.
The implementation here is not production code. It is not fused, not tuned for tensor cores, and makes three passes over the attention scores where a flash attention implementation, using an online softmax, would make one. But the patterns here (shared memory tiling, tree reductions, numerically stable softmax) are the same ones production kernels build on, without the extra complexity that comes from optimizing for throughput at scale.
Sometimes it helps to see the skeleton before studying the muscle.
Full Code
#include <cuda_runtime.h>
#include <math_constants.h>
#include <cmath>
#define CHECK_CUDA(x) do { cudaError_t err = (x); if (err != cudaSuccess) return; } while(0)
static __device__ __forceinline__ float gelu_tanh(float x) {
// GELU(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
const float k = 0.7978845608028654f; // sqrt(2/pi)
float x3 = x * x * x;
return 0.5f * x * (1.0f + tanhf(k * (x + 0.044715f * x3)));
}
// ---------------------------
// Reductions (block-wide)
// ---------------------------
__device__ __forceinline__ float block_reduce_sum(float v) {
// assumes blockDim.x is a power of two <= 1024
__shared__ float smem[1024];
int tid = threadIdx.x;
smem[tid] = v;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) smem[tid] += smem[tid + stride];
__syncthreads();
}
float total = smem[0];
// all threads must read smem[0] before a later call overwrites it
// (layernorm768_kernel calls this twice in a row)
__syncthreads();
return total;
}
__device__ __forceinline__ float block_reduce_max(float v) {
// assumes blockDim.x is a power of two <= 1024
__shared__ float smem[1024];
int tid = threadIdx.x;
smem[tid] = v;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) smem[tid] = fmaxf(smem[tid], smem[tid + stride]);
__syncthreads();
}
float result = smem[0];
__syncthreads();   // same reuse hazard as block_reduce_sum
return result;
}
// ---------------------------
// LayerNorm over 768 features
// One block per token
// ---------------------------
__global__ void layernorm768_kernel(const float* __restrict__ in,
float* __restrict__ out,
const float* __restrict__ gamma,
const float* __restrict__ beta,
int seq_len) {
int t = blockIdx.x;
if (t >= seq_len) return;
const float* x = in + t * 768;
float* y = out + t * 768;
// sum
float sum = 0.0f;
for (int i = threadIdx.x; i < 768; i += blockDim.x) sum += x[i];
sum = block_reduce_sum(sum);
float mean = sum / 768.0f;
// var
float vsum = 0.0f;
for (int i = threadIdx.x; i < 768; i += blockDim.x) {
float d = x[i] - mean;
vsum += d * d;
}
vsum = block_reduce_sum(vsum);
float var = vsum / 768.0f;
float inv_std = rsqrtf(var + 1e-5f);
for (int i = threadIdx.x; i < 768; i += blockDim.x) {
float n = (x[i] - mean) * inv_std;
y[i] = n * gamma[i] + beta[i];
}
}
// ---------------------------
// Simple tiled GEMM (row-major)
// C = A(M,K) * B(K,N) + bias(N)
// TILE=16; N,K multiples of 16 give the best perf (true for every shape in this block); bounds checks handle the rest
// ---------------------------
template<int TILE>
__global__ void matmul_bias_tiled(const float* __restrict__ A,
const float* __restrict__ B,
const float* __restrict__ bias,
float* __restrict__ C,
int M, int K, int N) {
int row = blockIdx.y * TILE + threadIdx.y;
int col = blockIdx.x * TILE + threadIdx.x;
__shared__ float As[TILE][TILE];
__shared__ float Bs[TILE][TILE];
float acc = 0.0f;
for (int k0 = 0; k0 < K; k0 += TILE) {
float a = 0.0f;
float b = 0.0f;
if (row < M && (k0 + threadIdx.x) < K)
a = A[row * K + (k0 + threadIdx.x)];
if (col < N && (k0 + threadIdx.y) < K)
b = B[(k0 + threadIdx.y) * N + col];
As[threadIdx.y][threadIdx.x] = a;
Bs[threadIdx.y][threadIdx.x] = b;
__syncthreads();
#pragma unroll
for (int k = 0; k < TILE; ++k)
acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
__syncthreads();
}
if (row < M && col < N) {
acc += (bias ? bias[col] : 0.0f);
C[row * N + col] = acc;
}
}
// ---------------------------
// Add residual: y = a + b
// ---------------------------
__global__ void add_kernel(const float* __restrict__ a,
const float* __restrict__ b,
float* __restrict__ y,
int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] = a[i] + b[i];
}
// ---------------------------
// GELU elementwise
// ---------------------------
__global__ void gelu_kernel(float* __restrict__ x, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) x[i] = gelu_tanh(x[i]);
}
// ---------------------------
// Attention kernel:
// One block per (token t, head h)
// Input: qkv [seq, 2304] = [Q|K|V] each 768
// Output: attn_out [seq, 768] (concatenated heads)
// No causal mask
//
// Shared memory: scores[seq_len] floats
// blockDim.x recommended 256
// ---------------------------
__global__ void mha_nocausal_kernel(const float* __restrict__ qkv,
float* __restrict__ attn_out,
int seq_len) {
int t = blockIdx.x; // query token
int h = blockIdx.y; // head 0..11
if (t >= seq_len || h >= 12) return;
extern __shared__ float shmem[]; // size >= seq_len floats
float* scores = shmem;
const int D = 768;
const int Dh = 64;
const float inv_sqrt_dh = 1.0f / sqrtf((float)Dh);
// Q pointer for this (t,h)
const float* Q = qkv + t * (3 * D) + 0 * D + h * Dh;
// Pass 1: compute scores[s] = dot(Q, K_s) / sqrt(Dh)
float local_max = -1e30f;
for (int s = threadIdx.x; s < seq_len; s += blockDim.x) {
const float* K = qkv + s * (3 * D) + 1 * D + h * Dh;
float dot = 0.0f;
#pragma unroll
for (int i = 0; i < Dh; ++i) dot += Q[i] * K[i];
float sc = dot * inv_sqrt_dh;
scores[s] = sc;
local_max = fmaxf(local_max, sc);
}
float max_sc = block_reduce_max(local_max);
__syncthreads();
// Pass 2: exp(scores - max), sum
float local_sum = 0.0f;
for (int s = threadIdx.x; s < seq_len; s += blockDim.x) {
float e = expf(scores[s] - max_sc);
scores[s] = e;
local_sum += e;
}
float sum_sc = block_reduce_sum(local_sum);
float inv_sum = 1.0f / sum_sc;
__syncthreads();
// Pass 3: weighted sum of V
if (threadIdx.x < Dh) {
int i = threadIdx.x;
float acc = 0.0f;
for (int s = 0; s < seq_len; ++s) {
float p = scores[s] * inv_sum;
const float* V = qkv + s * (3 * D) + 2 * D + h * Dh;
acc += p * V[i];
}
attn_out[t * D + h * Dh + i] = acc;
}
}
// ---------------------------
// Entry point (device pointers)
// x: (seq_len, 768)
// output: (seq_len, 768)
// weights packed per the offset table above
// ---------------------------
extern "C" void solve(const float* x, float* output, const float* weights, int seq_len) {
const int D = 768;
const int FF = 3072;
const int H = 12;
const float* gamma1 = weights + 0;
const float* beta1 = weights + 768;
const float* W_qkv = weights + 1536;
const float* b_qkv = weights + 1771008;
const float* W_attn = weights + 1773312;
const float* b_attn = weights + 2363136;
const float* gamma2 = weights + 2363904;
const float* beta2 = weights + 2364672;
const float* W_fc = weights + 2365440;
const float* b_fc = weights + 4724736;
const float* W_proj = weights + 4727808;
const float* b_proj = weights + 7087104;
float *ln1 = nullptr, *qkv = nullptr, *attn = nullptr, *proj = nullptr, *x1 = nullptr;
float *ln2 = nullptr, *ff1 = nullptr, *ff2 = nullptr;
size_t bytes_ln = (size_t)seq_len * D * sizeof(float);
size_t bytes_qkv = (size_t)seq_len * (3 * D) * sizeof(float);
size_t bytes_ff1 = (size_t)seq_len * FF * sizeof(float);
CHECK_CUDA(cudaMalloc(&ln1, bytes_ln));
CHECK_CUDA(cudaMalloc(&qkv, bytes_qkv));
CHECK_CUDA(cudaMalloc(&attn, bytes_ln));
CHECK_CUDA(cudaMalloc(&proj, bytes_ln));
CHECK_CUDA(cudaMalloc(&x1, bytes_ln));
CHECK_CUDA(cudaMalloc(&ln2, bytes_ln));
CHECK_CUDA(cudaMalloc(&ff1, bytes_ff1));
CHECK_CUDA(cudaMalloc(&ff2, bytes_ln));
// 1) LN1
layernorm768_kernel<<<seq_len, 256>>>(x, ln1, gamma1, beta1, seq_len);
// 2) QKV = ln1 * W_qkv + b_qkv (M=seq_len, K=768, N=2304)
{
const int TILE = 16;
dim3 block(TILE, TILE);
dim3 grid((2304 + TILE-1)/TILE, (seq_len + TILE-1)/TILE);
matmul_bias_tiled<TILE><<<grid, block>>>(ln1, W_qkv, b_qkv, qkv, seq_len, 768, 2304);
}
// 3) MHA: attn = (seq_len, 768)
{
dim3 grid(seq_len, H);
size_t shmem = (size_t)seq_len * sizeof(float);
mha_nocausal_kernel<<<grid, 256, shmem>>>(qkv, attn, seq_len);
}
// 4) proj = attn * W_attn + b_attn
{
const int TILE = 16;
dim3 block(TILE, TILE);
dim3 grid((768 + TILE-1)/TILE, (seq_len + TILE-1)/TILE);
matmul_bias_tiled<TILE><<<grid, block>>>(attn, W_attn, b_attn, proj, seq_len, 768, 768);
}
// 5) x1 = x + proj
{
int n = seq_len * D;
add_kernel<<<(n+255)/256, 256>>>(x, proj, x1, n);
}
// 6) LN2
layernorm768_kernel<<<seq_len, 256>>>(x1, ln2, gamma2, beta2, seq_len);
// 7) ff1 = ln2 * W_fc + b_fc (M=seq_len, K=768, N=3072)
{
const int TILE = 16;
dim3 block(TILE, TILE);
dim3 grid((3072 + TILE-1)/TILE, (seq_len + TILE-1)/TILE);
matmul_bias_tiled<TILE><<<grid, block>>>(ln2, W_fc, b_fc, ff1, seq_len, 768, 3072);
}
// 8) GELU(ff1)
{
int n = seq_len * FF;
gelu_kernel<<<(n+255)/256, 256>>>(ff1, n);
}
// 9) ff2 = ff1 * W_proj + b_proj (M=seq_len, K=3072, N=768)
{
const int TILE = 16;
dim3 block(TILE, TILE);
dim3 grid((768 + TILE-1)/TILE, (seq_len + TILE-1)/TILE);
matmul_bias_tiled<TILE><<<grid, block>>>(ff1, W_proj, b_proj, ff2, seq_len, 3072, 768);
}
// 10) output = x1 + ff2
{
int n = seq_len * D;
add_kernel<<<(n+255)/256, 256>>>(x1, ff2, output, n);
}
// make sure every kernel has finished before its buffers are released
cudaDeviceSynchronize();
cudaFree(ln1); cudaFree(qkv); cudaFree(attn); cudaFree(proj);
cudaFree(x1); cudaFree(ln2); cudaFree(ff1); cudaFree(ff2);
}