• The granularity of Triton is the block, not the individual thread

    • i.e. in CUDA, your kernel defines exactly what each thread should do, and the block size is an abstraction that lets the GPU schedule blocks efficiently
    • in Triton, your kernel defines what a group of threads (you don’t know how many, or more precisely you don’t need to worry about it) should do with a given chunk of data (usually a block of BLOCK_SIZE elements)
• In a Triton kernel, you receive a pointer to the first element of each input tensor

    • This is because the GPU itself knows nothing about the logical layout of the tensor it receives from PyTorch; it only sees a pointer into flat memory
    • This is why you need to be comfortable with tensor layouts, and why you will usually pass the tensor strides as kernel arguments as well (see the short example below)
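
As a quick illustration of the point above, the sketch below (the tensor shape and values are arbitrary, chosen only for illustration) prints the raw pointer and the strides that PyTorch would hand to a Triton kernel:

import torch

x = torch.arange(12, device='cuda', dtype=torch.float32).reshape(3, 4)
print(x.data_ptr())  # raw address of the first element (what x_ptr would point to)
print(x.stride())    # (4, 1): how many elements to skip per step along each dimension
# Element x[i, j] lives at x_ptr + i * stride(0) + j * stride(1).
# The kernel has to rebuild addresses this way, because the pointer alone
# carries no shape or layout information.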

Vector-addition kernel

import torch
import triton
import triton.language as tl
 
@triton.jit
def add_kernel(x_ptr,       # Pointer to first input vector.
               y_ptr,       # Pointer to second input vector.
               output_ptr,  # Pointer to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr):  # Number of elements each program should process.
    # There are multiple 'programs' processing different data. We identify which program we are here:
    pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64], [64:128], [128:192], [192:256].
    # Note that offsets is a list of element indices; x_ptr + offsets below turns it into a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
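
For completeness, here is a sketch of the host-side wrapper that launches this kernel (essentially the wrapper from the standard Triton vector-add tutorial; the helper name add, the input size, and BLOCK_SIZE=1024 are conventional choices, not requirements):

def add(x: torch.Tensor, y: torch.Tensor):
    # Preallocate the output; the kernel writes into it through output_ptr.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # 1D launch grid: one program per block of BLOCK_SIZE elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # Each torch.Tensor is implicitly converted to a pointer to its first element.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
print(torch.allclose(add(x, y), x + y))  # expected: True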