CuTe DSL Brings Custom NVIDIA GPU Kernels Into JAX
Katja Sirazitdinova presents CuTe DSL for JAX as an escape hatch for developers who need more GPU-kernel control than XLA provides, without leaving the JAX workflow. The tutorial argues that custom NVIDIA kernels can be written in Python with CUTLASS CuTe DSL, bridged into compiled JAX programs through `cutlass_call`, and used where fusion, special layouts, custom data movement, or narrow performance bottlenecks justify the extra discipline around shapes, dtypes, launch constraints, and static parameters.

The reason to leave the default JAX path
Katja Sirazitdinova starts from a narrow premise: JAX on NVIDIA GPUs is already fast, but some workloads eventually need control XLA will not automatically provide. Her examples are a fused operation XLA will not generate, a custom memory layout, a non-standard primitive, or a kernel aimed at a specific performance bottleneck.
Her answer is not to leave JAX. It is to write a GPU kernel in Python with CUTLASS CuTe DSL, then call it from JAX as if it were a native JAX operation. The workflow moves from a custom kernel, to execution inside @jax.jit, to Ahead-of-Time compilation with jax.export, where the compiled artifact can be serialized, shipped or cached, and loaded later without retracing.
The stack has three parts: CUTLASS supplies NVIDIA’s building blocks for high-performance GPU kernels, especially GEMM and tensor operations; CuTe supplies the tensor and layout abstraction, treating a tensor as data plus layout (a pointer plus a layout); JAX supplies composability with the larger compiled program. CuTe lets indexing be expressed in tensor coordinates while the layout handles the address math.
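The "coordinates in, address math handled by the layout" idea can be illustrated without any GPU code. The sketch below is plain Python, not the CuTe API: `layout_offset`, `ROW_MAJOR`, and `COL_MAJOR` are hypothetical names, and the only assumption is the standard rule that a strided layout maps a coordinate to the sum of coordinate times stride.

```python
from typing import Tuple


def layout_offset(coord: Tuple[int, ...], stride: Tuple[int, ...]) -> int:
    """Map a tensor coordinate to a linear offset: sum(coord_i * stride_i)."""
    return sum(c * s for c, s in zip(coord, stride))


# The same logical 4x8 tensor under two hypothetical layouts.
SHAPE = (4, 8)
ROW_MAJOR = (8, 1)  # stepping along a row moves 1 element in memory
COL_MAJOR = (1, 4)  # stepping down a column moves 1 element in memory

coord = (2, 3)
row_off = layout_offset(coord, ROW_MAJOR)  # 2*8 + 3*1
col_off = layout_offset(coord, COL_MAJOR)  # 2*1 + 3*4
```

The calling code uses the same coordinate `(2, 3)` in both cases. Only the layout decides which memory offset it resolves to.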
The practical point is that CuTe kernels are still CUDA kernels. Threads and blocks are real. The developer reads threadIdx and blockIdx, controls launch shapes, and decides how data maps onto the GPU execution model. CuTe DSL separates that work into two layers: @cute.kernel defines the per-thread program, while @cute.jit defines the launcher, including grid, block, and CUDA stream.
That stream detail is central to the JAX integration. The launcher takes a CUDA stream as its first argument, and Sirazitdinova says that stream is managed by XLA. The custom kernel therefore runs on the same execution timeline as the rest of the JAX program, rather than as an out-of-band CUDA launch.
The reusable pattern: launcher, custom call, shape contract
Vector addition is the minimal version of the integration contract. In the CuTe kernel, each thread loads one element from A and one from B into registers, adds them, and stores the result. The code uses thread and block indexes, register fragments, and CuTe copy helpers to move data between global memory and registers.
The launcher chooses the number of threads per block and the number of blocks in the grid. In the JAX wrapper, the key step is wrapping that launcher with cutlass.jax.cutlass_call. That produces a callable that can execute as a JAX custom call inside @jax.jit.
cutlass_call bridges worlds.
On one side, there is a CuTe launcher that expects a CUDA stream and tensors. On the other, there is a JAX function that must compose inside XLA. cutlass_call turns the launcher into a JAX primitive or custom call, so it can sit inside a larger jax.jit function, combine with other JAX operations, and be scheduled by XLA as part of the compiled program.
That bridge imposes two recurring responsibilities on the developer. First, the JAX data must be presented in the layout the kernel expects. In the vector-add example, a one-dimensional vector is padded to a multiple of the block size, reshaped into a three-dimensional view — elements per thread, threads per block, number of blocks — passed to the kernel, then reshaped back and sliced to remove padding. Sirazitdinova describes that reshape as a layout reinterpretation, not a data copy.
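The pad, reshape, call, reshape-back sequence can be sketched in pure JAX. This is an illustration under stated assumptions, not the tutorial's code: `stand_in_kernel` is a hypothetical placeholder where the `cutlass_call`-wrapped launcher would go, the constants are hypothetical, and the axis order simply follows the description above (the order a real kernel expects depends on its CuTe layout).

```python
import jax
import jax.numpy as jnp

ELEMS_PER_THREAD = 1     # hypothetical: one element per thread
THREADS_PER_BLOCK = 256  # hypothetical block size
ELEMS_PER_BLOCK = ELEMS_PER_THREAD * THREADS_PER_BLOCK


def stand_in_kernel(a3, b3):
    # Placeholder for the custom call; a real CuTe kernel would run here.
    return a3 + b3


@jax.jit
def padded_vector_add(a, b):
    n = a.shape[0]                              # static under jit
    num_blocks = -(-n // ELEMS_PER_BLOCK)       # ceiling division
    pad = num_blocks * ELEMS_PER_BLOCK - n
    a_p = jnp.pad(a, (0, pad))                  # pad to a whole number of blocks
    b_p = jnp.pad(b, (0, pad))
    view = (ELEMS_PER_THREAD, THREADS_PER_BLOCK, num_blocks)
    out3 = stand_in_kernel(a_p.reshape(view), b_p.reshape(view))
    return out3.reshape(-1)[:n]                 # flatten, then slice off padding


a = jnp.arange(1000, dtype=jnp.float32)
b = jnp.ones(1000, dtype=jnp.float32)
out = padded_vector_add(a, b)
```

With 1000 elements and 256 per block, the wrapper pads to 1024, runs four blocks' worth of data, and returns the first 1000 results.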
Second, JAX must know the output shape and dtype up front so XLA can compile the graph. The reusable pattern is simple but strict: match the kernel’s layout expectations, and make output shape and dtype explicit.
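The same "declare the output up front" requirement appears in other JAX mechanisms that call code XLA cannot inspect. The sketch below uses `jax.pure_callback`, which is a different mechanism from `cutlass_call` and is shown only as an analogy: the caller must hand JAX a `jax.ShapeDtypeStruct` describing the result before anything runs. `host_add` and `add_via_callback` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_add(a, b):
    # Opaque to XLA: stands in for code the compiler cannot trace through.
    return np.asarray(a) + np.asarray(b)


@jax.jit
def add_via_callback(a, b):
    # The output contract is stated explicitly: same shape and dtype as `a`.
    out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(host_add, out_spec, a, b)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.full(8, 2.0, dtype=jnp.float32)
result = add_via_callback(x, y)
```

If the declared shape or dtype does not match what the opaque code actually produces, the compiled graph around it is built on a wrong assumption, which is why the contract has to be exact.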
Parameters, activations, and where fusion pays
SAXPY extends the same pattern with a scalar parameter: out = alpha * X + Y. The padding, reshaping, custom call, and reshape-back steps are the same as vector add. The new question is whether alpha is dynamic or static at compile time.
For this example, Katja Sirazitdinova treats alpha as static. It is marked static in jax.jit and passed into cutlass_call as a keyword argument. Her stated reason is that a static argument is known at trace time and can be baked into the compiled code path. She frames this as a useful pattern for configuration parameters or constants where a developer wants to avoid the cost of dynamic dispatch.
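What "static" means on the JAX side can be seen with a pure-JAX stand-in. In this sketch the body `alpha * x + y` replaces the custom call, and `trace_log` is a hypothetical helper that records each time Python tracing runs. The trade-off it shows is standard `jax.jit` behavior: a static value is baked in at trace time, so each new value causes a fresh trace and compile.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # appended to only while JAX is tracing the function


@partial(jax.jit, static_argnames=("alpha",))
def saxpy(x, y, *, alpha):
    trace_log.append(alpha)  # Python side effect: runs at trace time only
    return alpha * x + y     # stand-in for the cutlass_call-wrapped kernel


x = jnp.arange(4, dtype=jnp.float32)
y = jnp.ones(4, dtype=jnp.float32)

r1 = saxpy(x, y, alpha=2.0)  # first call: traces and compiles
r2 = saxpy(x, y, alpha=2.0)  # same static value: reuses the compiled code
r3 = saxpy(x, y, alpha=3.0)  # new static value: traces again
```

After these three calls `trace_log` holds two entries, one per distinct alpha. That is cheap for a handful of configuration values and expensive for a scalar that changes on every step.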
The ReLU example switches indexing style. Instead of a three-dimensional elements-per-thread layout, it uses flat one-dimensional indexing: compute a global linear index from block and thread indexes, bounds-check it against N, load the value, apply max(x, 0), and store. The wrapper flattens the input, calls the kernel with N, and reshapes the output back to the original shape.
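The flat-indexing logic can be emulated in JAX to make the bounds check visible. This is not a GPU kernel: it builds every (block, thread) pair as array data, and `THREADS_PER_BLOCK` and the function name are hypothetical. The guarded load and guarded store mirror what each thread would do.

```python
import jax
import jax.numpy as jnp

THREADS_PER_BLOCK = 256  # hypothetical launch parameter


@jax.jit
def relu_flat_emulated(x):
    flat = x.reshape(-1)                       # wrapper flattens the input
    n = flat.shape[0]
    num_blocks = -(-n // THREADS_PER_BLOCK)    # enough blocks to cover n
    block_idx = jnp.arange(num_blocks)[:, None]
    thread_idx = jnp.arange(THREADS_PER_BLOCK)[None, :]
    # Global linear index, one per emulated thread.
    gid = (block_idx * THREADS_PER_BLOCK + thread_idx).reshape(-1)
    in_bounds = gid < n                        # the kernel's bounds check
    safe = jnp.where(in_bounds, gid, 0)        # out-of-range threads read slot 0
    vals = jnp.maximum(flat[safe], 0)          # max(x, 0)
    # Guarded store: indices >= n are dropped instead of written.
    out = jnp.zeros_like(flat).at[gid].set(vals, mode="drop")
    return out.reshape(x.shape)                # wrapper restores the shape


x = jnp.linspace(-1.0, 1.0, 300, dtype=jnp.float32).reshape(3, 100)
y = relu_flat_emulated(x)
```

With 300 elements and 256 threads per block, two blocks launch and 212 emulated threads fall outside the array. The bounds check is what keeps them from touching memory.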
The fused bias-plus-ReLU example is where the custom-kernel case becomes more concrete for machine-learning workloads. The unfused computation is z = x + bias followed by y = ReLU(z) = max(0, z). Done separately, Sirazitdinova says it typically costs two kernel launches plus an intermediate write to global memory and a read back. The fused kernel loads x, loads the bias for the relevant column, adds, applies ReLU, and stores once.
"So the message here isn't 'CuTe is a replacement for every JAX op.' The message is, when you have a hotspot that is memory traffic dominated or launch overhead dominated, fusion is a huge lever."
The wrapper also passes width as a static parameter because it affects indexing and is often useful as a compile-time constant.
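Counting only full-array passes, the unfused version reads x, writes z, reads z, and writes y (four passes), while the fused version reads x and writes y (two), plus the small bias reads in both cases. The sketch below shows the indexing a fused kernel would need, written in pure JAX as a stand-in: each flat element finds its bias through `index % width`, which is why width is a natural static parameter. The function names are hypothetical and a row-major input is assumed.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("width",))
def bias_relu_flat(x_flat, bias, *, width):
    idx = jnp.arange(x_flat.shape[0])
    col = idx % width                          # column of each flat element
    return jnp.maximum(x_flat + bias[col], 0)  # add and ReLU, one store


def bias_relu(x, bias):
    rows, width = x.shape
    out = bias_relu_flat(x.reshape(-1), bias, width=width)
    return out.reshape(rows, width)


x = jnp.array([[-1.0, 2.0, -3.0, 4.0],
               [0.5, -0.5, 1.5, -1.5],
               [3.0, -2.0, 1.0, 0.0]], dtype=jnp.float32)
bias = jnp.array([1.0, -1.0, 0.5, -0.5], dtype=jnp.float32)
y = bias_relu(x, bias)
```

XLA may fuse this particular elementwise pattern on its own, so the stand-in illustrates the index math and the static width parameter rather than any performance gain.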
| Example | What it demonstrates | JAX-side pattern |
|---|---|---|
| Vector add | End-to-end custom kernel called from JAX | Pad, reshape to kernel layout, call, reshape back |
| SAXPY | Scalar parameters such as alpha | Treat alpha as static and pass it into cutlass_call |
| ReLU | Flat CUDA-like indexing | Flatten input, pass N, reshape output |
| Fused bias + ReLU | Fusion to reduce launches and memory traffic | Pass width as a static indexing parameter |
GEMM is for mechanics, not beating cuBLAS
The tiled GEMM example is explicitly about mechanics, not replacement. Katja Sirazitdinova says the goal is not to beat cuBLAS as called under the hood by jnp.matmul; the simple implementation will not. The point is to show how to choose tile sizes, map blocks to output tiles, and pass problem sizes such as M, N, and K into the launcher from JAX.
The wrapper flattens A and B, provides M, N, and K, calls the custom kernel, and reshapes the flat output back to (M, N). That keeps the example aligned with the rest of the workflow: the custom operation is useful because it teaches launch configuration and integration, not because every hand-written kernel should replace a library primitive.
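The block-to-tile mapping can be emulated on the host to show the mechanics. In this sketch each loop iteration plays the role of one thread block that owns one output tile. The tile sizes and function name are hypothetical, and the per-tile product uses an ordinary matmul where a real kernel would loop over K inside the block.

```python
import jax.numpy as jnp

TILE_M, TILE_N = 2, 2  # hypothetical tile sizes


def tiled_gemm_emulated(a_flat, b_flat, m, n, k):
    a = a_flat.reshape(m, k)
    b = b_flat.reshape(k, n)
    c = jnp.zeros((m, n), dtype=a.dtype)
    grid_m = -(-m // TILE_M)   # number of tile rows (ceiling division)
    grid_n = -(-n // TILE_N)   # number of tile columns
    for bm in range(grid_m):   # each (bm, bn) pair is one emulated block
        for bn in range(grid_n):
            r0, c0 = bm * TILE_M, bn * TILE_N
            r1, c1 = min(r0 + TILE_M, m), min(c0 + TILE_N, n)  # clip edge tiles
            tile = a[r0:r1, :] @ b[:, c0:c1]
            c = c.at[r0:r1, c0:c1].set(tile)
    return c.reshape(-1)       # flat output, as the wrapper expects


M, N, K = 5, 3, 4
A = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)
B = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N)
C = tiled_gemm_emulated(A.reshape(-1), B.reshape(-1), M, N, K).reshape(M, N)
```

Here M and N are not multiples of the tile size, so the last tile row and column are clipped. A real kernel would handle the same edge with bounds checks or padding.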
The same integration path is presented as preserving JAX composability, including multi-GPU sharding. A cutlass_call kernel can participate in JAX sharding: JAX shards arrays across devices, each device runs the same custom call on its local shard, and JAX handles coordination.
The example creates a device mesh, defines sharding specs, and uses jax.shard_map to run vector add across multiple GPUs without changing the kernel code itself. Sirazitdinova presents that as a key advantage of integrating through JAX rather than launching CUDA kernels outside the compiled JAX execution path.
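A minimal version of that structure, with a pure-JAX `local_add` standing in for the custom call, might look like the sketch below. It assumes a one-dimensional mesh over whatever devices are visible (it also runs on a single CPU device), and it falls back to the older `jax.experimental.shard_map` import when `jax.shard_map` is not available.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# jax.shard_map exists in newer JAX releases; older ones ship it as experimental.
shard_map = getattr(jax, "shard_map", None)
if shard_map is None:
    from jax.experimental.shard_map import shard_map


def local_add(a, b):
    # Runs once per device on that device's shard; stand-in for the kernel.
    return a + b


devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))  # 1-D mesh with a single axis named "x"

sharded_add = jax.jit(
    shard_map(local_add, mesh=mesh, in_specs=(P("x"), P("x")), out_specs=P("x"))
)

n = 8 * devices.size          # keep the length divisible by the device count
a = jnp.arange(n, dtype=jnp.float32)
b = jnp.ones(n, dtype=jnp.float32)
out = sharded_add(a, b)
```

The per-device function never mentions the mesh. Splitting the inputs and reassembling the output are described entirely by the partition specs.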
Exporting a JAX program that contains a CuTe kernel
Ahead-of-Time compilation is the final step in the workflow. Normally, Katja Sirazitdinova says, a jax.jit function compiles in the current Python process. That is useful for research and iteration, but production often wants a different workflow: compile once, serialize an artifact, ship or cache it, and load it later without retracing. That is where jax.export enters.
The example function combines a CUTLASS custom call — an element-wise add kernel — with a regular JAX operation, jax.nn.sigmoid. A pure JAX reference implementation computes sigmoid(a + b). The exported version includes the custom call, so the code explicitly disables the relevant export safety check using CUTLASS’s helper.
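The export round trip itself can be sketched with a pure-JAX stand-in for the add kernel. Because nothing here is a custom call, no safety check needs to be disabled; the CUTLASS helper mentioned above is not named in the source and is not shown. Serialization may also require the optional `flatbuffers` package to be installed.

```python
import jax
import jax.numpy as jnp
from jax import export


def stand_in_add(a, b):
    # Placeholder for the CUTLASS element-wise add custom call.
    return a + b


@jax.jit
def add_then_sigmoid(a, b):
    return jax.nn.sigmoid(stand_in_add(a, b))


def reference(a, b):
    return jax.nn.sigmoid(a + b)


# Concrete shapes: the artifact is locked to (4, 64) float32 inputs.
spec = jax.ShapeDtypeStruct((4, 64), jnp.float32)
exported = export.export(add_then_sigmoid)(spec, spec)

blob = exported.serialize()          # bytes that can be written to disk
restored = export.deserialize(blob)  # load without retracing the Python code

a = jnp.linspace(-1.0, 1.0, 256, dtype=jnp.float32).reshape(4, 64)
b = jnp.ones((4, 64), dtype=jnp.float32)
out = restored.call(a, b)
```

The `restored` object carries the compiled computation and its input signature, so the loading side does not need the original Python function.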
Sirazitdinova identifies two details: the exported computation includes custom calls, and shapes matter. With concrete shapes, the serialized artifact is locked to those dimensions. With symbolic shapes, the same artifact can be reused across multiple input sizes, within the kernel’s launch constraints. The symbolic-shape example shows reuse for shapes where M * N is a multiple of the kernel’s block size, 256, and tests several shapes against the reference implementation.
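Symbolic shapes can be sketched the same way. The example below exports a pure-JAX stand-in once with symbolic dimensions `m` and `n`, then calls it on several concrete shapes. The multiple-of-256 rule is enforced by a hypothetical host-side wrapper, `call_checked`, as an assumption about how such a launch constraint could be guarded. Serialization is left out to keep the sketch short.

```python
import jax
import jax.numpy as jnp
from jax import export

BLOCK = 256  # block size quoted for the kernel's launch constraint


@jax.jit
def add_then_sigmoid(a, b):
    return jax.nn.sigmoid(a + b)  # stand-in for custom call plus sigmoid


# One export covers every (m, n) instead of a single fixed shape.
m, n = export.symbolic_shape("m, n")
spec = jax.ShapeDtypeStruct((m, n), jnp.float32)
exported = export.export(add_then_sigmoid)(spec, spec)


def call_checked(a, b):
    # Hypothetical guard: reject shapes the kernel could not launch on.
    if (a.shape[0] * a.shape[1]) % BLOCK != 0:
        raise ValueError(f"M * N must be a multiple of {BLOCK}, got {a.shape}")
    return exported.call(a, b)


results = {}
for shape in [(2, 128), (4, 64), (16, 32)]:
    a = jnp.ones(shape, jnp.float32)
    b = jnp.zeros(shape, jnp.float32)
    results[shape] = call_checked(a, b)
```

All three shapes reuse the same exported object. A shape such as (3, 5) is rejected by the guard before it reaches the compiled code.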
The closing guidance is operational. Be explicit about output shape and dtype when building custom calls. Decide early whether a parameter should be dynamic or static; scalars such as alpha and configuration values such as width are often best treated as static. Watch padding and launch constraints: if a kernel assumes full blocks, pad inputs and slice the output. And use the approach where it fits — fusion, special layouts, custom data movement, and kernels standard libraries do not provide.


