<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>kernel-programming on echen's blog</title><link>https://blog.echen.io/tags/kernel-programming/</link><description>Recent content in kernel-programming on echen's blog</description><generator>Hugo -- gohugo.io</generator><language>en-us</language><lastBuildDate>Mon, 18 May 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://blog.echen.io/tags/kernel-programming/index.xml" rel="self" type="application/rss+xml"/><item><title>FlashAttention-2 in CuTe, from Scratch</title><link>https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/</link><pubDate>Mon, 18 May 2026 00:00:00 +0000</pubDate><guid>https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/</guid><description>&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/thumbnail.png" alt="Featured image of post FlashAttention-2 in CuTe, from Scratch" />&lt;p>Two days. That&amp;rsquo;s how long it took me to write FlashAttention-2 in Triton. I had never touched a GPU kernel before and somehow replicated a revolutionary algorithm between sips of coffee. I was looking to learn something new and somehow finished before I really got started. I knew the original source was written in CUDA, made by my Bay Area buddy who turned rocks into trillion-ade. Riding on my hubris from my recent success, I decided to translate my pièce de résistance to the native dialect of the rock-whisperers themselves.&lt;/p>
&lt;p>Three hundred cups of not-Java later, I raised the blood-soaked script in triumph. I spent many moons deciphering the barren scriptures and getting led astray by stochastic parrots and silicon oracles. That month was not a party with the rocks but a bad trip I barely came down from.&lt;/p>
&lt;p>Every kernel developer has to walk this path eventually. I&amp;rsquo;ve seen the fear in the eyes of the avoiders and the bodies along the way. This post is for whoever&amp;rsquo;s next.&lt;/p>
&lt;hr>
&lt;p>We&amp;rsquo;ll walk through FlashAttention-2 end-to-end on an A100, implemented in C++ CuTe: GMEM/SMEM async copies, tiled MMAs, swizzling, online softmax, and the epilogue store. We&amp;rsquo;ll cover every detail in the source code, every decision, and every throwaway line that ruined an evening.&lt;/p>
&lt;p>A quick note on what this blog is supposed to be &amp;ndash; CuTe&amp;rsquo;s documentation is a reference, not a tutorial. You can read it cover to cover and walk away still not knowing CuTe; it&amp;rsquo;s only really worth consulting once you&amp;rsquo;re already trying to do something specific and have hit a wall. This post attempts to be that something specific, and we&amp;rsquo;re going to run straight into those walls together. Most public CuTe writing covers one layer at a time: a layout here, a swizzle there, an MMA atom over there. Tying them into a real kernel is where the difficulty actually lies, and that&amp;rsquo;s the gap this post is trying to fill.&lt;/p>
&lt;p>The code we&amp;rsquo;ll walk through is a stripped-down mirror of Tri Dao&amp;rsquo;s production FA-2: same idioms, same building blocks, often the same lines, but with the causal/RoPE/dropout/KV-cache/QK-smem-sharing template branches removed. The core logic is visible instead of buried under config flags that bloat up the repo. Where it matters, our kernel reaches close to parity with the source &amp;ndash; on an A100, &lt;strong>88-105% of production FA-2&amp;rsquo;s throughput&lt;/strong> across hdim=64/128 and seq lengths up to 64K, peaking at 63% of fp16 tensor-core utilization (&lt;a class="link" href="https://github.com/cloudui/cuda-triton#cute-flashattention-2-cuda-batch4-heads8" target="_blank" rel="noopener"
>full benchmark table&lt;/a>).&lt;/p>
&lt;p>The point is not novelty &amp;ndash; it&amp;rsquo;s just to show that our simplifications do not break performance. We are rewriting one production case of many, albeit the simplest and least commonly used (decoder LLMs almost all use causal masking). Where our code diverges, that&amp;rsquo;s usually because we found something &amp;ndash; an inconsistency, a copy-paste from a CuTe example that most take for granted, a one-line simplification, a choice that turns out to be critical for some non-trivial reason. These moments are flagged in-line throughout the post, and at least one of them (&lt;a class="link" href="#svtnoswizzle-the-no-op-nobody-caught" >the &lt;code>sVtNoSwizzle&lt;/code> line&lt;/a>) appears to be a no-op that nobody in the lineage of this code understood. Make of that what you will.&lt;/p>
&lt;p>What this isn&amp;rsquo;t: a summary of the FA-2 paper, a Hopper/Blackwell post (newer algorithms are meaningfully different on newer hardware), or a CuTe guide. This is Ampere-specific, code-level, and committed to the bit that we don&amp;rsquo;t move on from a line until we fully understand why it&amp;rsquo;s there.&lt;/p>
&lt;h1 id="flashattention-2">FlashAttention-2&lt;/h1>
&lt;p>If you&amp;rsquo;re reading this, I&amp;rsquo;ll assume you already have a solid understanding of the attention mechanism and at least the basics of the FlashAttention-2 algorithm itself. If not, I recommend reading the original flash attention paper&lt;sup id="fnref:1">&lt;a href="#fn:1" class="footnote-ref" role="doc-noteref">1&lt;/a>&lt;/sup> before coming back. Or, you could just read the article as-is, because you&amp;rsquo;ll probably piece it together through the struggle of trying to understand. It would be helpful to at least know the pseudocode/baseline algorithm for FA2, and even better if you&amp;rsquo;ve tried simulating it in PyTorch (or your framework of choice) or maybe even wrote it in Triton.&lt;/p>
&lt;p>If you&amp;rsquo;ve never touched CUDA, you should at least try to understand its SIMT programming nature and maybe implement a few basic kernels using this thread-level view. Try to build a solid understanding of how CUDA works and of NVIDIA&amp;rsquo;s GPU architecture, from threads to warps to thread blocks to SMs and beyond. I&amp;rsquo;ll talk about a lot of these concepts in detail, but I still assume a basic understanding of GPU or hardware paradigms. I will be as comprehensive as I can, but it will be an uphill battle should you try to read this blog in its entirety without &lt;em>some&lt;/em> background knowledge.&lt;/p>
&lt;p>Most of this blog concerns how high-level concepts like &amp;ldquo;online softmax&amp;rdquo; or &amp;ldquo;GEMM&amp;rdquo; actually translate to production-grade code. The algorithm itself is not particularly difficult in theory, but the implementation details at the CUDA level can become a nightmare, particularly for beginners. Tri Dao originally wrote FA2 using &lt;strong>CuTe&lt;/strong> (CUDA Templates), the layout-algebra core inside NVIDIA&amp;rsquo;s CUTLASS 3.x library &amp;ndash; see &lt;a class="link" href="#why-cute-and-whats-cutlass" >Why CuTe&lt;/a> for the philosophy behind it and how it differs from Triton, WMMA, and raw CUDA. The short version is that CuTe doesn&amp;rsquo;t abstract the hardware away &amp;ndash; it gives you the full-fat algebra to describe hardware-aware layouts exactly. Simply writing this in CuTe will force you to understand and optimize for the hardware.&lt;/p>
&lt;p>In the Blackwell (B200) era, NVIDIA also released CuTe&amp;rsquo;s Python DSL &amp;ndash; a Python library you can use to write the same code without all the annoying templating that comes bundled with C++. The use case and methodology are pretty much unchanged, but debugging and metaprogramming become more palatable, and compile times are dramatically shorter because kernels are JIT-compiled from Python instead of going through heavy C++ template instantiation. Moving forward, the C++ CuTe from CUTLASS 3.x that we use today will probably become somewhat of a relic, but as a learning exercise, nothing beats the absolute struggle of working with the most annoying and explicit version of whatever you&amp;rsquo;re trying to learn.&lt;/p>
&lt;h1 id="overview">Overview&lt;/h1>
&lt;h2 id="design-choices">Design Choices&lt;/h2>
&lt;p>We&amp;rsquo;re going to make some basic design choices to make this learning exercise simpler on the implementation side. A lot of the source code involves edge cases and optional configuration settings (RoPE, QK smem sharing, etc.) that aren&amp;rsquo;t practical for learning the fundamentals of FA2 and CuTe. Our choices are as follows:&lt;/p>
&lt;ul>
&lt;li>A100: the GPU I had access to and the industry standard when FA2 was released. Newer architecture generations like Hopper (H100) and Blackwell (B200) have even more complicated algorithms (e.g. FA3, FA4) due to hardware improvements and optimizations.&lt;/li>
&lt;li>fp16: supported on A100 tensor cores, pretty basic default for most kernel ops during training
&lt;ul>
&lt;li>fp32 accumulation: reduces precision drift and keeps the softmax statistics and scaling accurate&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>Clean basic out-of-the-box attention mechanism: no causal masking, RoPE, dropout, etc.&lt;/li>
&lt;li>head_dim: we focus on 32, 64, and 128&lt;/li>
&lt;li>Assume the sequence length is a multiple of the Q block size (any power of two at least as large as the block size qualifies).&lt;/li>
&lt;li>Expect $Q, K, V$ to be contiguous along &lt;code>head_dim&lt;/code>, i.e. all of shape &lt;code>(seqlen, head_dim)&lt;/code> in PyTorch/JAX.&lt;/li>
&lt;/ul>
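&lt;p>To make the fp32-accumulation choice concrete, here is a small NumPy experiment. It is a host-side illustration under simple assumptions, not kernel code, and the &lt;code>accumulate&lt;/code> helper is hypothetical. It sums 4096 fp16 inputs while rounding the running sum to either fp16 or fp32 after every add, roughly what happens to an accumulator register over a long reduction.&lt;/p>
&lt;pre>&lt;code class="language-python">
```python
import numpy as np

def accumulate(values, acc_dtype):
    """Sum fp16 inputs with the running sum held in acc_dtype, rounding after every add."""
    acc = acc_dtype(0)
    for v in values:
        acc = acc_dtype(acc + acc_dtype(v))
    return acc

# 4096 copies of 0.1, first rounded to fp16 (0.0999755859375), like an fp16 input tile.
x = np.full(4096, 0.1, dtype=np.float16)
exact = 4096 * float(x[0])  # exact sum of the fp16 inputs, computed in float64

acc16 = accumulate(x, np.float16)  # fp16 accumulator
acc32 = accumulate(x, np.float32)  # fp32 accumulator
print(f"exact={exact}  fp16 acc={float(acc16)}  fp32 acc={float(acc32)}")
```
&lt;/code>&lt;/pre>
&lt;p>The fp16 accumulator stops growing once the gap between adjacent fp16 values near the running sum exceeds twice the addend, so it ends far below the exact sum, while fp32 represents every partial sum here exactly. The MMA atom we use, &lt;code>SM80_16x8x16_F32F16F16F32_TN&lt;/code>, makes the same trade: fp16 inputs, fp32 accumulators.&lt;/p>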
&lt;h2 id="some-naming-conventions">Some Naming Conventions&lt;/h2>
&lt;p>If you look at the FA2 source code, you might notice they have some weird naming conventions. Some of them are standard CuTe/CUTLASS, some carry over from other things. Here are some patterns:&lt;/p>
&lt;ul>
&lt;li>Starts with k: compile-time constant, e.g. &lt;code>kBlockM&lt;/code>, &lt;code>kHeadDim&lt;/code>&lt;/li>
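&lt;li>$M, N, K$: all general matrix-multiply (GEMM) dimensions follow this order for an $(M, K) \times (K, N)$ matrix multiply. For the first GEMM, $QK^T$, that means $M$ = &lt;code>kBlockM&lt;/code>, $N$ = &lt;code>kBlockN&lt;/code>, and $K$ = &lt;code>kHeadDim&lt;/code>, so the Q tile has shape &lt;code>(kBlockM, kHeadDim)&lt;/code> and the K, V tiles have shape &lt;code>(kBlockN, kHeadDim)&lt;/code>.&lt;/li>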
&lt;li>&lt;code>kBlockKSmem&lt;/code>: width of the SMEM &amp;ldquo;row tile&amp;rdquo; used for the swizzle atom, separate from &lt;code>kHeadDim&lt;/code>. Capped at 64 so one &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code> atom can serve every hdim that&amp;rsquo;s a multiple of 64. Covered in detail in &lt;a class="link" href="#swizzling-fa2" >Swizzling FA2&lt;/a>.&lt;/li>
&lt;li>&lt;strong>TN convention&lt;/strong> (for our tiled MMA atom &lt;code>SM80_16x8x16_F32F16F16F32_TN&lt;/code>): A is row-major in M-K and B is column-major in K-N from the matmul&amp;rsquo;s perspective. Practical effect: K is the contiguous dimension for both A and B, which matches our &lt;code>(seqlen, head_dim)&lt;/code> layout for Q, K.&lt;/li>
&lt;li>FA2 and CuTe have a weird but relatively consistent variable naming scheme for tensors. One example gives the idea: &lt;code>tSrQ&lt;/code>. t=thread-partitioned, S=the $QK^T$ GEMM that feeds the softmax, r=registers, Q=query matrix. They use &lt;code>s&lt;/code> for SMEM and &lt;code>g&lt;/code> for GMEM. Tensors that aren&amp;rsquo;t partitioned per thread have no leading &lt;code>t&lt;/code>.
&lt;ul>
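&lt;li>In this post&amp;rsquo;s notation, $S$ is the post-softmax matrix and $P=QK^T/\sqrt{d_h}$ is the pre-softmax score, but FA2 consistently calls the pre-softmax $QK^T$ accumulator S, following the FA2 paper&amp;rsquo;s notation.&lt;/li>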
&lt;/ul>
&lt;/li>
&lt;/ul>
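&lt;p>Since &lt;code>kBlockKSmem&lt;/code> and &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code> show up well before their detailed treatment, here is a plain-Python model of the idea. It assumes CuTe&amp;rsquo;s positive-shift definition of &lt;code>Swizzle&amp;lt;B,M,S&amp;gt;&lt;/code> applied to an element offset (XOR the B bits starting at bit M+S into the B bits starting at bit M) and a simple bank model of 32 four-byte banks; &lt;code>swizzle&lt;/code> and &lt;code>chunk_of&lt;/code> are hypothetical helpers, not CuTe API.&lt;/p>
&lt;pre>&lt;code class="language-python">
```python
def swizzle(offset: int, b: int = 3, m: int = 3, s: int = 3) -> int:
    """Model of CuTe's Swizzle(B, M, S) for S > 0 on an element offset:
    XOR the B bits starting at bit M + S into the B bits starting at bit M."""
    yyy = (offset // 2 ** (m + s)) % 2 ** b   # the B "row" bits
    return offset ^ (yyy * 2 ** m)            # folded into the B "chunk" bits

K_BLOCK_K_SMEM = 64   # fp16 elements per SMEM row: 64 * 2 B = 128 B = all 32 banks once
ELEM_BYTES = 2        # sizeof(fp16)

def chunk_of(offset: int) -> int:
    """Index (0..7) of the 16-byte chunk, i.e. group of 4 banks, an element maps to."""
    return (offset * ELEM_BYTES // 16) % 8

# Eight threads each read the first 16-byte chunk (8 fp16 values) of rows 0..7.
offsets = [row * K_BLOCK_K_SMEM for row in range(8)]
plain = [chunk_of(o) for o in offsets]
swizzled = [chunk_of(swizzle(o)) for o in offsets]
print("unswizzled:", plain)     # all rows map to chunk 0, i.e. banks 0-3
print("swizzled:  ", swizzled)  # row r maps to chunk r, covering all 32 banks
```
&lt;/code>&lt;/pre>
&lt;p>With a 64-element fp16 row spanning exactly 128 bytes, the first 16-byte chunk of rows 0..7 sits in banks 0-3 for every row, an 8-way conflict for a column-wise read. After the swizzle, row r&amp;rsquo;s chunk lands in position r, so the eight reads cover all 32 banks.&lt;/p>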
&lt;h2 id="basic-structure">Basic Structure&lt;/h2>
&lt;p>First, attention itself:&lt;/p>
&lt;p>$$\begin{aligned}
P &amp;amp;= \frac{QK^T}{\sqrt{d_h}} \\
S &amp;amp;= \text{softmax}(P) \\
O &amp;amp;= S \cdot V
\end{aligned}$$&lt;/p>
&lt;p>Or in PyTorch, for those who haven&amp;rsquo;t read a math equation in a while:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">P&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Q&lt;/span>&lt;span class="nd">@K&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">/&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sqrt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">d_h&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">S&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">functional&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">softmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">P&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dim&lt;/span>&lt;span class="o">=-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">O&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">S&lt;/span>&lt;span class="nd">@V&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="high-level-details">High-Level Details&lt;/h3>
&lt;p>Let&amp;rsquo;s establish the specifics of the FA2 algorithm at a high level.&lt;/p>
&lt;ul>
&lt;li>Our grid is &lt;code>batch/head&lt;/code> x &lt;code>q_tile&lt;/code>. The batch/head dimensions are independent and can be grouped. The &lt;code>q_tile&lt;/code> determines which tile of Q we get, and we make it the last dimension for better cache locality between thread blocks.&lt;/li>
&lt;li>The Q tile has shape &lt;code>(kBlockM, kHeadDim)&lt;/code>, and the main computation in each thread block revolves around it. This Q tile does not change for the lifetime of the thread block: we iterate over every K, V tile to produce a single output tile. Each Q tile maps to an output tile of exactly the same shape, and the block must own that tile until the loop ends, because the online softmax needs every K/V tile&amp;rsquo;s contribution to a row before that output row is final.&lt;/li>
&lt;li>We load each tile from global memory (GMEM) to shared memory (SMEM) for staging. When we need to do our GEMM, we load from SMEM to the register file as we loop over K and V.&lt;/li>
&lt;li>Our GMEM-&amp;gt;SMEM copies are all async (&lt;code>cp.async&lt;/code> on Ampere). Q doesn&amp;rsquo;t strictly need to be, since it stays constant throughout the thread block; its single GMEM load doesn&amp;rsquo;t overlap much compute, but we make the optimization anyway.&lt;/li>
&lt;li>&lt;strong>Warps tile along M, not N.&lt;/strong> This is &lt;em>the&lt;/em> defining layout decision of the kernel. We arrange our tensor-core warps so that each warp owns entire rows of the $QK^T$ output, never a slice across the row. The reason is the online softmax: row max and row sum reductions stay &lt;em>inside a warp&lt;/em> and resolve via warp-shuffle primitives (&lt;code>__shfl_xor_sync&lt;/code>) &amp;ndash; no shared-memory staging, no &lt;code>__syncthreads&lt;/code>. This single decision turns softmax from a bottleneck into something nearly free. Most subsequent design choices, such as the &lt;code>TiledMma&lt;/code> shape, fragment partitioning, the &lt;code>Softmax&lt;/code> struct, and the rescale loop, fall out of it.&lt;/li>
&lt;/ul>
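&lt;p>The quad reduction behind the &amp;ldquo;warps tile along M&amp;rdquo; choice can be modeled in a few lines of Python. This sketch assumes the SM80 m16n8k16 accumulator layout, where lanes 4g..4g+3 of a warp hold the values for rows g and g+8 of the output tile, simplifies each lane to one partial max, and simulates &lt;code>__shfl_xor_sync&lt;/code> on plain lists; &lt;code>shfl_xor&lt;/code> and &lt;code>quad_allreduce_max&lt;/code> are hypothetical names.&lt;/p>
&lt;pre>&lt;code class="language-python">
```python
import random

def shfl_xor(values, lane_mask):
    """Model of __shfl_xor_sync: lane i reads the value held by lane (i XOR lane_mask)."""
    return [values[lane ^ lane_mask] for lane in range(len(values))]

def quad_allreduce_max(values):
    """Butterfly max within each quad of 4 consecutive lanes: XOR offsets 2, then 1."""
    for offset in (2, 1):
        partner = shfl_xor(values, offset)
        values = [max(a, b) for a, b in zip(values, partner)]
    return values

# One warp = 32 lanes. Each lane first reduces its own registers for a row;
# lanes 4g..4g+3 (a quad) then still need to agree on the max for that row.
random.seed(0)
thread_partial_max = [random.uniform(-5.0, 5.0) for _ in range(32)]
row_max = quad_allreduce_max(thread_partial_max)
print(row_max[:8])  # lanes 0-3 share one value, lanes 4-7 share another
```
&lt;/code>&lt;/pre>
&lt;p>Two butterfly steps (XOR with 2, then 1) leave every lane of a quad holding the quad&amp;rsquo;s maximum, with no shared memory and no block-wide barrier. Swapping &lt;code>max&lt;/code> for addition gives the row sum.&lt;/p>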
&lt;h3 id="the-kernel-outline">The Kernel Outline&lt;/h3>
&lt;ol>
&lt;li>Define GMEM, SMEM, register files, hardware copy/GEMM instructions, and mappings&lt;/li>
&lt;li>Load Q tile from global memory to SMEM. This is only done once, as Q tile doesn&amp;rsquo;t change.&lt;/li>
&lt;li>Prefetch 0th K-tile.&lt;/li>
&lt;li>Loop start: wait for the K-tile to arrive. Then, issue the async copy of the matching V-tile.&lt;/li>
&lt;li>Issue GEMM for $P = QK^T$ tile.&lt;/li>
&lt;li>Wait for V-tile to arrive. Then, issue next K-tile prefetch if we&amp;rsquo;re not on the last iteration.&lt;/li>
&lt;li>Compute $S=\text{softmax}(P)$ and softmax statistics and update accumulator/output tile.&lt;/li>
&lt;li>Issue GEMM for $O = SV$ tile.&lt;/li>
&lt;li>Loop back to step 4 until every K/V tile has been processed.&lt;/li>
&lt;li>Divide the final output by $l$, the running row sum of $\exp(P - m)$, where $m$ is the running row max.&lt;/li>
&lt;li>Copy output from SMEM back to GMEM. Ampere doesn&amp;rsquo;t have any direct SMEM-&amp;gt;GMEM instructions so we stage this copy through registers.&lt;/li>
&lt;/ol>
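&lt;p>Steps 4-10 can be checked on the host before any CuTe is involved. The NumPy sketch below is a float64 reference for one Q tile, not the kernel: it ignores the async copies and register layouts, assumes the K/V sequence length is a multiple of &lt;code>block_n&lt;/code>, and uses a hypothetical &lt;code>flash_attention_forward&lt;/code> name. It does keep the parts that matter numerically: the running max, the rescale of the old statistics and accumulator, and the single division at the end.&lt;/p>
&lt;pre>&lt;code class="language-python">
```python
import numpy as np

def flash_attention_forward(q, k, v, block_n):
    """Reference sketch of the FA2 forward loop for ONE Q tile.
    q: (block_m, d); k, v: (seqlen_k, d); seqlen_k assumed a multiple of block_n."""
    block_m, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(block_m, -np.inf)        # running row max
    l = np.zeros(block_m)                # running row sum of exp(P - m)
    acc_o = np.zeros((block_m, d))       # unnormalized output accumulator
    for start in range(0, k.shape[0], block_n):         # steps 4-9: loop over K/V tiles
        k_tile = k[start:start + block_n]
        v_tile = v[start:start + block_n]
        p = (q @ k_tile.T) * scale                      # step 5: P tile = Q K^T / sqrt(d)
        m_new = np.maximum(m, p.max(axis=1))            # step 7: update the row max
        correction = np.exp(m - m_new)                  # rescale factor for old statistics
        s = np.exp(p - m_new[:, None])                  # unnormalized softmax of this tile
        l = l * correction + s.sum(axis=1)
        acc_o = acc_o * correction[:, None] + s @ v_tile  # step 8: O += S V
        m = m_new
    return acc_o / l[:, None]                           # step 10: divide by the row sum

rng = np.random.default_rng(0)
q = rng.standard_normal((64, 64))
k = rng.standard_normal((256, 64))
v = rng.standard_normal((256, 64))
out = flash_attention_forward(q, k, v, block_n=64)

# Compare against materializing the whole attention matrix at once.
p_full = q @ k.T / np.sqrt(64)
s_full = np.exp(p_full - p_full.max(axis=1, keepdims=True))
reference = (s_full / s_full.sum(axis=1, keepdims=True)) @ v
print(np.abs(out - reference).max())
```
&lt;/code>&lt;/pre>
&lt;p>Because the rescale by &lt;code>exp(m - m_new)&lt;/code> is exact algebra, the tiled result should agree with the materialized softmax up to floating-point rounding, which is what lets the kernel stream K and V one tile at a time.&lt;/p>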
&lt;p>Only 11 steps, and they&amp;rsquo;re all pretty simple in concept&amp;hellip; Let&amp;rsquo;s take a deeper look at the implementation details.&lt;/p>
&lt;h1 id="code-layout-the-repo">Code Layout: The Repo&lt;/h1>
&lt;p>All of the code in this post lives in &lt;a class="link" href="https://github.com/cloudui/cuda-triton" target="_blank" rel="noopener"
>&lt;code>@github:cloudui/cuda-triton&lt;/code>&lt;/a>. This repo contains every kernel in Triton and CUDA I&amp;rsquo;ve written along the way to writing FA2 in CuTe. The two-day Triton FA2 from the intro is in &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/kernels/flash_attention_full.py" target="_blank" rel="noopener"
>&lt;code>kernels/flash_attention_full.py&lt;/code>&lt;/a>; the intermediate WMMA CUDA FA2 (the bridge between Triton and CuTe) is in &lt;a class="link" href="https://github.com/cloudui/cuda-triton/tree/main/cuda/flash_attn" target="_blank" rel="noopener"
>&lt;code>cuda/flash_attn/&lt;/code>&lt;/a>; the &lt;a class="link" href="https://github.com/cloudui/cuda-triton/tree/main/cute" target="_blank" rel="noopener"
>&lt;code>cute/&lt;/code>&lt;/a> directory is an in-progress port of this kernel to NVIDIA&amp;rsquo;s new Python CuTe DSL. The post itself walks through &lt;a class="link" href="https://github.com/cloudui/cuda-triton/tree/main/cuda/flash_attn_cutlass" target="_blank" rel="noopener"
>&lt;code>cuda/flash_attn_cutlass/&lt;/code>&lt;/a> &amp;ndash; the C++ CuTe implementation. Everything has tests and benchmarks, and the mostly up-to-date top-level &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/README.md" target="_blank" rel="noopener"
>&lt;code>README.md&lt;/code>&lt;/a> is worth browsing.&lt;/p>
&lt;blockquote>
&lt;p>Please refer to Tri Dao&amp;rsquo;s repo as well. It contains much more edge case handling and templating that I do not touch on at all: &lt;a class="link" href="https://github.com/Dao-AILab/flash-attention/tree/main" target="_blank" rel="noopener"
>@github:Dao-AILab/flash-attention&lt;/a>.&lt;/p>
&lt;p>Specifically, their implementation lives in &lt;a class="link" href="https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src" target="_blank" rel="noopener"
>csrc/flash_attn&lt;/a>.&lt;/p>
&lt;/blockquote>
&lt;p>Before we dive into the implementation, here&amp;rsquo;s the file map for the kernel we&amp;rsquo;re walking through. The snippets in the blog are simplified versions of what&amp;rsquo;s in the repo. I&amp;rsquo;ll cite specific files as we go, but the up-front overview is:&lt;/p>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th>File&lt;/th>
&lt;th>What&amp;rsquo;s in it&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash.h" target="_blank" rel="noopener"
>&lt;code>flash.h&lt;/code>&lt;/a>&lt;/td>
&lt;td>The &lt;code>Flash_fwd_params&lt;/code> struct &amp;ndash; pointers, batch/head strides, sizes. The runtime side of the kernel API.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/kernel_traits.cuh" target="_blank" rel="noopener"
>&lt;code>kernel_traits.cuh&lt;/code>&lt;/a>&lt;/td>
&lt;td>Compile-time type composition: SMEM Layouts with &lt;code>Swizzle&lt;/code>, &lt;code>TiledMma&lt;/code>, &lt;code>Copy_Atom&lt;/code>s, and block sizes. Any &lt;code>using&lt;/code> C++ declarations in this blog live here.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash_fwd_kernel.h" target="_blank" rel="noopener"
>&lt;code>flash_fwd_kernel.h&lt;/code>&lt;/a>&lt;/td>
&lt;td>The kernel body: Q load, KV main loop, softmax, epilogue. This is where most of the blog&amp;rsquo;s code snippets come from.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash_fwd_launch_template.h" target="_blank" rel="noopener"
>&lt;code>flash_fwd_launch_template.h&lt;/code>&lt;/a>&lt;/td>
&lt;td>Grid/block sizing, &lt;code>cudaFuncSetAttribute&lt;/code> for extended SMEM, kernel launch.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/tree/main/cuda/flash_attn_cutlass" target="_blank" rel="noopener"
>&lt;code>flash_fwd_hdim{32,64,128}_*.cu&lt;/code>&lt;/a>&lt;/td>
&lt;td>Per-config explicit instantiations so we don&amp;rsquo;t have to JIT.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash_api.cu" target="_blank" rel="noopener"
>&lt;code>flash_api.cu&lt;/code>&lt;/a>&lt;/td>
&lt;td>PyTorch extension entry point and runtime dispatch by &lt;code>head_dim&lt;/code>.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/softmax.cuh" target="_blank" rel="noopener"
>&lt;code>softmax.cuh&lt;/code>&lt;/a>&lt;/td>
&lt;td>The &lt;code>Softmax&amp;lt;kNRows&amp;gt;&lt;/code> struct, &lt;code>softmax_rescale_o&lt;/code>, warp/quad reductions.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/utils.cuh" target="_blank" rel="noopener"
>&lt;code>utils.cuh&lt;/code>&lt;/a>&lt;/td>
&lt;td>Layout-rewrite helpers (&lt;code>convert_layout_rowcol&lt;/code>, &lt;code>convert_layout_acc_Aregs&lt;/code>), &lt;code>convert_type&lt;/code>, copy helpers.&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>&lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/setup.py" target="_blank" rel="noopener"
>&lt;code>setup.py&lt;/code>&lt;/a>&lt;/td>
&lt;td>Build script with the CUTLASS include path.&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
&lt;p>&lt;strong>Where the blog&amp;rsquo;s snippets live.&lt;/strong> Most of the &lt;code>cpp&lt;/code> blocks in the post are stripped-down versions of the production code:&lt;/p>
&lt;ul>
&lt;li>&lt;code>kernel_traits.cuh&lt;/code> &amp;ndash; everything in &lt;a class="link" href="#tiled-mma" >Tiled MMA&lt;/a>, &lt;a class="link" href="#copy-atoms" >Copy Atoms&lt;/a>, &lt;a class="link" href="#swizzling-fa2" >Swizzling FA2&lt;/a>, and the V-copy SMEM layouts.&lt;/li>
&lt;li>&lt;code>flash_fwd_kernel.h&lt;/code> &amp;ndash; everything in &lt;a class="link" href="#gmem-and-smem-tensors" >GMEM and SMEM Tensors&lt;/a>, &lt;a class="link" href="#q-k-smem-register-tiled-copy" >Q,K SMEM-&amp;gt;Register Tiled Copy&lt;/a>, &lt;a class="link" href="#mma-loop-qkt-gemm" >MMA Loop: QK^T GEMM&lt;/a>, &lt;a class="link" href="#the-actual-async-copy-strategy" >The Actual Async Copy Strategy&lt;/a>, and the &lt;a class="link" href="#epilogue-output-gmem" >Epilogue&lt;/a>.&lt;/li>
&lt;li>&lt;code>softmax.cuh&lt;/code> &amp;ndash; everything in &lt;a class="link" href="#online-softmax" >Online Softmax&lt;/a>.&lt;/li>
&lt;li>&lt;code>utils.cuh&lt;/code> &amp;ndash; the &lt;code>convert_layout_rowcol&lt;/code> reshape in &lt;a class="link" href="#fragment-reshape" >Fragment Reshape&lt;/a>.&lt;/li>
&lt;/ul>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: I often combine the &lt;code>kernel_traits&lt;/code> declarations and the &lt;code>flash_fwd_kernel&lt;/code> code to keep them in one block, and I sometimes leave out function declarations that wrap certain blocks of code for brevity. If you ever become confused, all important sections link to their source at the top of their subsections for reference.&lt;/p>
&lt;/blockquote>
&lt;p>&lt;strong>A note on parallel scratch.&lt;/strong> Alongside the production kernel, &lt;a class="link" href="https://github.com/cloudui/cuda-triton/tree/main/scratch" target="_blank" rel="noopener"
>&lt;code>scratch/&lt;/code>&lt;/a> contains the small standalone CuTe demos I wrote while losing my mind in confusion. The &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/README.md" target="_blank" rel="noopener"
>scratch README&lt;/a> maps each file to a blog section. If a concept ever feels too abstract on the page, run the corresponding scratch file and stare at its output for a while. It might help. The instructions to run are in the repo&amp;rsquo;s READMEs.&lt;/p>
&lt;p>&lt;strong>A note on simplification.&lt;/strong> As called out in the &lt;a class="link" href="#flashattention-but-the-actual-details" >intro&lt;/a>, this blog walks through a stripped-down mirror of &lt;a class="link" href="https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src" target="_blank" rel="noopener"
>Tri Dao&amp;rsquo;s production FA-2&lt;/a>. Where the source has branches for causal masking, RoPE, KV-cache, dropout, QK SMEM sharing, etc., this kernel doesn&amp;rsquo;t &amp;ndash; the load-bearing FA2 logic is what&amp;rsquo;s left. Wherever this kernel diverges from the source in a non-trivial way, I flag it inline.&lt;/p>
&lt;h1 id="cute-the-basics">CuTe, the Basics&lt;/h1>
&lt;p>As established in the intro, the CuTe docs are a reference; this section is the tutorial that doesn&amp;rsquo;t exist. I won&amp;rsquo;t cover all the APIs &amp;ndash; you can intuit 90% of them from context and the FA2 code. However, the concepts that turn your evenings into late nights will be waiting for you here. It&amp;rsquo;s hard to internalize the motivations for certain CuTe features until you&amp;rsquo;ve encountered the problem they&amp;rsquo;re meant to solve, so if a section feels abstract, skip ahead to the FA2 implementation and come back when you hit the wall it was written for.&lt;/p>
&lt;blockquote>
&lt;p>You can find the official docs at &lt;a class="link" href="https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html" target="_blank" rel="noopener"
>https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html&lt;/a>; they&amp;rsquo;re worth keeping open in another tab.&lt;/p>
&lt;/blockquote>
&lt;h2 id="why-cute-and-whats-cutlass">Why CuTe (and what&amp;rsquo;s CUTLASS?)&lt;/h2>
&lt;p>&lt;strong>CUTLASS&lt;/strong> (probably writing out what it stands for takes about as long as this parenthetical)&lt;sup id="fnref:2">&lt;a href="#fn:2" class="footnote-ref" role="doc-noteref">2&lt;/a>&lt;/sup> is NVIDIA&amp;rsquo;s open-source library of GEMM and GEMM-adjacent building blocks &amp;ndash; templated kernels, atoms, copy primitives, etc. &lt;strong>CuTe&lt;/strong> is the layout-algebra core &lt;em>inside&lt;/em> CUTLASS, introduced in version 3. CUTLASS 2.x was a different, more rigid GEMM-policy-composition framework; CUTLASS 3.x reorganized everything around CuTe, which is now the language we use to describe layouts, tile shapes, and thread-data mappings. &amp;ldquo;CuTe FA2&amp;rdquo; refers to FA2 written in the CUTLASS 3.x idiom, which is what the production source uses.&lt;/p>
&lt;p>Why use CuTe at all? There are other options for writing high-performance GPU kernels:&lt;/p>
&lt;ul>
&lt;li>&lt;strong>Triton:&lt;/strong> hides the hardware. You write &lt;code>tl.dot(q, k)&lt;/code> and Triton picks the MMA atom, swizzles SMEM, tiles for you, and the autotuner explores the configuration space. You think in tiles and Triton does all the hard work expressing it in PTX.&lt;/li>
&lt;li>&lt;strong>WMMA (&lt;code>nvcuda::wmma&lt;/code>):&lt;/strong> hides the fragment layout. You get opaque &lt;code>wmma::fragment&amp;lt;&amp;gt;&lt;/code> types and can&amp;rsquo;t easily reason about swizzling, LDSM behavior, or per-thread register state. It gets you onto the tensor cores quickly, but it&amp;rsquo;s not much help when you need to express something WMMA doesn&amp;rsquo;t model.&lt;/li>
&lt;li>&lt;strong>Raw PTX/SASS:&lt;/strong> &lt;sup id="fnref:3">&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref">3&lt;/a>&lt;/sup> I mean, sure, but bless your soul.&lt;/li>
&lt;/ul>
&lt;p>CuTe is the philosophical opposite of Triton. While Triton abstracts the hardware away, CuTe makes you do all the hard work yourself, which requires you to understand exactly what the hardware wants. In Triton, you don&amp;rsquo;t have to think at all about what the SMEM access pattern should be to hit all 32 banks. CuTe makes you draw all those bit masks and blocks on paper so your performance doesn&amp;rsquo;t get crushed by the holy memory manager himself. Triton does the entire MMA in one &lt;code>tl.dot&lt;/code> call while CuTe forces you to choose your exact MMA atom, copy strategy, and concatenation dim. You can do a lot in Triton without understanding what&amp;rsquo;s happening deep down &amp;ndash; CuTe hides behind its name and forces you to drag your knees through the dirt.&lt;/p>
&lt;p>This is good for the same reason it&amp;rsquo;s painful: you can&amp;rsquo;t write CuTe well without understanding the hardware, because there&amp;rsquo;s no abstraction to hide behind. You have the power to squeeze the maximum performance out of your hardware, and nothing is hidden behind a veil of nice API calls. If you read this blog in its entirety and follow along with the code, you&amp;rsquo;ll come out the other side understanding the Ampere memory pipeline, swizzling math, MMA atoms, LDSM semantics, and register-fragment behavior. You&amp;rsquo;d never have to learn any of those to write FA2 in Triton.&lt;/p>
&lt;p>You are also trading time for understanding: I wrote the Triton version in like two afternoons. It took me weeks to write and &amp;ldquo;understand&amp;rdquo; the first pass of the CuTe version, and only after spending almost 100 hours on this blog can I truly say I understand it. It&amp;rsquo;s the 80-20 rule on steroids: with Triton, you can probably get 80% of the performance in a tenth of the time.&lt;/p>
&lt;p>In practice, CuTe is essentially a templating engine that lets you manipulate memory using tensors, shapes, layouts, data types, and strides &amp;ndash; conceptually similar to PyTorch&amp;rsquo;s &lt;code>torch.Tensor&lt;/code> object, but much more granular and much more powerful. It lets you declare a general &amp;ldquo;shape&amp;rdquo; once and template it with fp32 vs fp16 by just passing different parameters. You&amp;rsquo;re still responsible for all the sizes. The code may extract fp16 from a 128-bit load, but you&amp;rsquo;ll have to figure out that 128 bits is 8 fp16 numbers. It just handles the typing on your behalf and lets you index things with nicer code. It certainly is not &amp;ldquo;easier,&amp;rdquo; and is often a nightmare to read. You&amp;rsquo;ll see why pretty soon.&lt;/p>
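&lt;p>The 128-bit arithmetic you are left doing by hand is simple, but it shows up everywhere: copy atoms, thread layouts, SMEM row widths. A minimal sketch with a hypothetical &lt;code>elems_per_access&lt;/code> helper:&lt;/p>
&lt;pre>&lt;code class="language-python">
```python
VECTOR_BITS = 128  # one 128-bit (16-byte) vectorized access, e.g. a 16-byte cp.async

def elems_per_access(elem_bits: int, vector_bits: int = VECTOR_BITS) -> int:
    """Number of elements of width elem_bits moved by one vectorized access."""
    if vector_bits % elem_bits:
        raise ValueError("element width must evenly divide the vector width")
    return vector_bits // elem_bits

fp16_per_load = elems_per_access(16)  # fp16 / bf16 -> 8 elements
fp32_per_load = elems_per_access(32)  # fp32        -> 4 elements
# Threads needed to copy one 64-element fp16 SMEM row with one 128-bit copy each:
copies_per_row = 64 // fp16_per_load
```
&lt;/code>&lt;/pre>
&lt;p>The same division tells you how many threads one 64-wide fp16 row needs if each thread issues a single 128-bit copy: eight.&lt;/p>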
&lt;h2 id="layouts-shapes-and-strides">Layouts, Shapes, and Strides&lt;/h2>
&lt;blockquote>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/01_layouts.cu" target="_blank" rel="noopener"
>&lt;code>scratch/01_layouts.cu&lt;/code>&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;p>Ah yes, back to tensor school. Shape and stride are the same concepts as in PyTorch. A layout is just a shape paired with a stride.&lt;/p>
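&lt;p>Before the C++, it helps to treat a layout as a function. The plain-Python sketch below uses hypothetical helpers that borrow the names of CuTe&amp;rsquo;s &lt;code>crd2idx&lt;/code> and &lt;code>idx2crd&lt;/code> (the 1-D mapping handles flat shapes only). It shows both halves: a coordinate maps to an offset through a dot product with the stride, and a 1-D index maps to a coordinate colexicographically, with the leftmost mode varying fastest. That second rule is where CuTe differs from PyTorch&amp;rsquo;s row-major flattening.&lt;/p>
&lt;pre>&lt;code class="language-python">
```python
def crd2idx(coord, shape, stride):
    """Layout function: a (possibly nested) coordinate -> memory offset,
    computed as the dot product of coordinate and stride, mode by mode."""
    if isinstance(shape, tuple):
        return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))
    return coord * stride

def idx2crd(i, shape):
    """1-D index -> coordinate for a flat shape, colexicographic:
    the LEFTMOST mode varies fastest."""
    coord = []
    for extent in shape:
        coord.append(i % extent)
        i //= extent
    return tuple(coord)

# Same layout as the C++ example that follows: shape (2, 8), stride (8, 1).
shape, stride = (2, 8), (8, 1)
offsets = [crd2idx(idx2crd(i, shape), shape, stride) for i in range(16)]
print(offsets)
```
&lt;/code>&lt;/pre>
&lt;p>For this &lt;code>(2,8):(8,1)&lt;/code> layout, indexing with 0, 1, 2, 3, ... returns offsets 0, 8, 1, 9, ...: the data is row-major in memory, but the layout&amp;rsquo;s 1-D index walks down each column first.&lt;/p>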
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#include&lt;/span> &lt;span class="cpf">&amp;lt;cute/tensor.hpp&amp;gt;&lt;/span>&lt;span class="cp">
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="c1">// runnable just like this without GPU
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">layout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">layout_1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">print_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// this is the shape of a torch.tensor([[0]*8 for _ in range(16)]).T
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// stdout
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="mi">0&lt;/span> &lt;span class="mi">1&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="mi">4&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="mi">6&lt;/span> &lt;span class="mi">7&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="o">+----+----+----+----+----+----+----+----+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="mi">0&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">0&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">1&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">4&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">6&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">7&lt;/span> &lt;span class="o">|&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="o">+----+----+----+----+----+----+----+----+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="mi">1&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">8&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">9&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">10&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">11&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">12&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">13&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">14&lt;/span> &lt;span class="o">|&lt;/span> &lt;span class="mi">15&lt;/span> &lt;span class="o">|&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="o">+----+----+----+----+----+----+----+----+&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This is a shape of (2, 8) with stride (8, 1), i.e. a 2x8 row-major layout. CuTe&amp;rsquo;s &lt;code>print_layout&lt;/code> function prints the shape and stride, followed by a grid showing the offset each coordinate maps to. Pretty simple. Both declarations produce the same layout. So what&amp;rsquo;s with the freaky underscores?&lt;/p>
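&lt;p>Without the C++ templates, a layout is just a function from a coordinate to an offset: the dot product of the coordinate with the stride. Here is a minimal Python sketch of that idea. The helper names are hypothetical and not part of CuTe. It rebuilds the same grid that &lt;code>print_layout&lt;/code> drew for &lt;code>(2, 8):(8, 1)&lt;/code>:&lt;/p>

```python
# A toy model of a 2-D layout: a (shape, stride) pair that maps a
# coordinate (i, j) to a memory offset. Illustrative only; not CuTe code.
def offset(coord, stride):
    """Dot product of coordinate and stride: i*stride[0] + j*stride[1]."""
    return sum(c * s for c, s in zip(coord, stride))


def layout_grid(shape, stride):
    """Offset of every coordinate, row by row, like print_layout's table."""
    rows, cols = shape
    return [[offset((i, j), stride) for j in range(cols)] for i in range(rows)]


for row in layout_grid((2, 8), (8, 1)):
    print(row)
# [0, 1, 2, 3, 4, 5, 6, 7]
# [8, 9, 10, 11, 12, 13, 14, 15]
```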
&lt;h2 id="statically-vs-dynamically-typed">Statically vs. Dynamically Typed&lt;/h2>
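&lt;p>Any plain C++ integer passed into a layout, shape, or stride is a &lt;em>dynamic&lt;/em> integer: CuTe treats its value as known only at runtime (e.g. &lt;code>int&lt;/code>, &lt;code>const int&lt;/code>, &lt;code>static int&lt;/code>). Even a C++ &lt;code>constexpr int&lt;/code> is treated this way, because once it is passed in, CuTe sees only its type (&lt;code>int&lt;/code>) and not its value. Any time you index into a tensor with a dynamic layout, the library computes the offset at runtime:&lt;/p>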
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="n">j&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="n">stride_row&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">j&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="n">stride_col&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Each index operation costs integer multiplies and adds at runtime, and in a hot inner loop those add up. Instead, when we can, we opt to use statics. These are type wrappers, used by CuTe (part of CUTLASS), that encode an integer&amp;rsquo;s value in its type so the value is known at compile time. It&amp;rsquo;s just a C++ compiler trick. It lets the compiler fold the layout arithmetic during compilation instead of making the GPU redo it while the kernel runs. Obviously, you can only do this if the sizes are predetermined: literals, template parameters, or other compile-time constants. So instead of passing in &lt;code>make_stride(2, 4)&lt;/code>, we can pass in &lt;code>make_stride(Int&amp;lt;2&amp;gt;{}, _4{})&lt;/code>. Functionally, these are the same, but in the latter the strides are baked into the type. When the coordinates are also compile-time constants (e.g. in fully unrolled loops), the entire offset is computed at compile time.&lt;/p>
&lt;p>Layouts &lt;em>do&lt;/em> take in dynamic integers as well. Use them only for values that are &lt;em>known only at runtime&lt;/em>, such as a sequence length passed in as a kernel argument.&lt;/p>
&lt;p>Some syntax quirks:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// identical, CuTe provides most power of twos by default as shorthand
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">_8&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Functions take in objects, types only use types
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">l1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">l2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{}),&lt;/span> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// type
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// object/struct
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// can have both dynamics and statics in same layout
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">shape&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">256&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Read more here:&lt;/p>
&lt;h2 id="tensors">Tensors&lt;/h2>
&lt;blockquote>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/02_tensor.cu" target="_blank" rel="noopener"
>&lt;code>scratch/02_tensor.cu&lt;/code>&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;p>A tensor is just a pointer paired with a layout. The pointer refers to the underlying data, which is usually contiguous, and the layout determines how coordinates map onto it. This is pretty much exactly how a &lt;code>torch.Tensor&lt;/code> works. The difference is that we manage the layout. We can change it to whatever we want, however we want, but we are ultimately responsible for the tensor&amp;rsquo;s integrity: nothing stops you from building a layout that reads out of bounds or maps two coordinates to the same element.&lt;/p>
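&lt;p>To make &amp;ldquo;pointer plus layout&amp;rdquo; concrete, here is a toy Python model. The &lt;code>ToyTensor&lt;/code> class is hypothetical and is not CuTe&amp;rsquo;s &lt;code>Tensor&lt;/code>. It builds two views over the same 32-element buffer that differ only in their strides. As in CuTe, nothing here checks bounds.&lt;/p>

```python
# A hypothetical, minimal model of a CuTe-style tensor: a flat buffer
# (standing in for the pointer) plus a (shape, stride) layout.
class ToyTensor:
    def __init__(self, data, shape, stride):
        self.data, self.shape, self.stride = data, shape, stride

    def __call__(self, i, j):
        # The layout turns (i, j) into an offset; the buffer supplies the value.
        return self.data[i * self.stride[0] + j * self.stride[1]]


x = list(range(1, 33))                      # 32 values: 1 through 32
row_major = ToyTensor(x, (8, 4), (4, 1))    # each row is contiguous
col_major = ToyTensor(x, (8, 4), (1, 8))    # each column is contiguous

print(row_major(2, 3))  # x[2*4 + 3*1] = x[11] = 12
print(col_major(2, 3))  # x[2*1 + 3*8] = x[26] = 27
```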
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">[]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">...,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">l_row_major&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// row-major view
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">t_row_major&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">l_row_major&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// column major view
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">l_col_major&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">t_col_major&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">l_col_major&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// tensor indexing: we&amp;#39;ll cover the &amp;#34;why&amp;#34; next
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// i, j = 2, 3
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">n_row&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">t_row_major&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="c1">// 12
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">n_col&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">t_col_major&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="c1">// 15
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="registers-arent-memory">Registers Aren&amp;rsquo;t Memory&lt;/h2>
&lt;p>One critical clarification before we go further. CuTe gives you &lt;code>Tensor&lt;/code> objects backed by GMEM, SMEM, &lt;em>and&lt;/em> registers, and they all look identical in the code: you index them, read their layouts, and pass them to the same functions. However, the register tensor is lying to you in a benign way: &lt;strong>registers are not addressable memory&lt;/strong>. They are slots in each SM&amp;rsquo;s register file, named directly by the operands of each instruction. There is no &amp;ldquo;register address.&amp;rdquo; The &amp;ldquo;layout&amp;rdquo; on a register tensor is purely a compiler-side mapping from a logical index (e.g. &lt;code>frag(0, 1)&lt;/code>) to a specific physical register name (e.g. &lt;code>%r17&lt;/code>). You should treat it like a 1-1 lookup table, not as something with strides you can do pointer arithmetic on. That mapping only works when every index is known at compile time. If you index a register fragment with a runtime value, the compiler has to fall back to (slow) local memory.&lt;/p>
&lt;p>This matters because:&lt;/p>
&lt;ul>
&lt;li>A register tensor &amp;ldquo;stride&amp;rdquo; is a code abstraction. A &amp;ldquo;column-major&amp;rdquo; register fragment doesn&amp;rsquo;t have any physical column-major memory underneath it &amp;ndash; there is no memory underneath at all.&lt;/li>
&lt;li>You cannot vectorize across a register layout the way you can across SMEM. Vectorization on registers depends on which values a thread holds and which hardware load/store instructions exist for that combination, not on what the layout looks like.&lt;/li>
&lt;li>A register tensor is implicitly &lt;em>per-thread&lt;/em>. Every thread in the warp has its own copy of the same &lt;code>Tensor&lt;/code> object, referring to its own physical registers. There is no shared register pool like SMEM; moving a value to another thread takes a warp shuffle or a round trip through SMEM.&lt;/li>
&lt;/ul>
&lt;p>We&amp;rsquo;ll lean on this every time we touch fragments. If a fragment-related thing seems impossible to reconcile with the layout you&amp;rsquo;re staring at, the answer is almost always &amp;ldquo;the layout is a fiction, the registers are real.&amp;rdquo; As simple as this sounds, you&amp;rsquo;ll see how this can be a huge point of confusion later.&lt;/p>
&lt;h2 id="layout-hell-row-and-column-major">Layout Hell: Row and Column Major&lt;/h2>
&lt;p>Welcome to layout hell. The layout intro from earlier probably seemed easy enough, but you&amp;rsquo;ll realize that 90% of the difficulty in understanding CuTe comes from layout algebra. You might think you understand easy concepts like row-major or column-major, but I&amp;rsquo;m here to tell you that unless you&amp;rsquo;ve sat down and drawn these stupid squares over and over again, you probably don&amp;rsquo;t.&lt;/p>
&lt;h3 id="row-major-the-first-layer-of-hell">Row-Major: The First Layer of Hell&lt;/h3>
&lt;p>A standard math matrix follows a row-major paradigm. There are M rows and N columns, and element $(i, j)$ is the element in the &lt;code>i-th&lt;/code> row along M and the &lt;code>j-th&lt;/code> column along N. If you&amp;rsquo;ve made arrays in most programming languages like C, Go, or numpy/torch, it&amp;rsquo;s precisely the same. Here&amp;rsquo;s a 4x6 row-major matrix, zero-indexed. We&amp;rsquo;ll call this the &lt;strong>logical view&lt;/strong>.&lt;/p>
&lt;p>$$
\begin{bmatrix}
a_{00} &amp;amp; a_{01} &amp;amp; a_{02} &amp;amp; a_{03} &amp;amp; a_{04} &amp;amp; a_{05} \\
a_{10} &amp;amp; a_{11} &amp;amp; a_{12} &amp;amp; a_{13} &amp;amp; a_{14} &amp;amp; a_{15} \\
a_{20} &amp;amp; a_{21} &amp;amp; a_{22} &amp;amp; a_{23} &amp;amp; a_{24} &amp;amp; a_{25} \\
a_{30} &amp;amp; a_{31} &amp;amp; a_{32} &amp;amp; a_{33} &amp;amp; a_{34} &amp;amp; a_{35}
\end{bmatrix}
$$&lt;/p>
&lt;blockquote>
&lt;p>Math majors might find this zero-indexing sacrilegious, but I&amp;rsquo;d imagine they&amp;rsquo;re not reading this blog anyway.&lt;/p>
&lt;/blockquote>
&lt;p>In the context of programming languages and memory, row-major also means that the items in each row are contiguous in memory. In other words, you walk through memory by stepping along the columns. This means $a_{i,j}$ is NEXT TO $a_{i,j+1}$ in memory. For a compact 2D row-major tensor of shape (M, N), N is the &lt;em>innermost&lt;/em> dimension, the one whose elements are adjacent in memory, so the N-stride is always $1$. The M-stride is simply the number of columns N: to get to the next row, you skip over the N elements of the current row.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// for row-major of shape 2,8
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">shape&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// the N-stride is always 1.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// With 8 columns per row, the m-stride is 8
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">row_major_stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Extrapolating to an n-dimensional row-major tensor of shape $(d_{n-1}, d_{n-2}, \dots, d_0)$, where $d_0$ is the innermost (rightmost) dimension: the 0th dim stride is 1, and the 1st dim stride is $d_{0}$. Each subsequent dimension has to step over the entire block nested inside it; one step along the 2nd dim skips a whole $d_{1}\cdot d_{0}$ block. Therefore:&lt;/p>
&lt;p>$$\text{stride}(x) = \prod_{i=0}^{x-1} d_i, \qquad \text{stride}(0) = 1 \text{ (empty product)}$$&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">shape&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">row_major_stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">630&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">126&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">18&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Each stride is the product of all the dimensions to its right, i.e. the size of the next dimension times that dimension&amp;rsquo;s stride. Easy enough.&lt;/p>
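&lt;p>The formula translates directly into a running product over the dimensions, innermost first. Here is a minimal Python sketch of it. The function name is ours, not CuTe&amp;rsquo;s, and it assumes a compact tensor with no padding. It reproduces both stride tuples above:&lt;/p>

```python
# Compact row-major strides: stride(x) = product of all dims to the right.
def row_major_strides(shape):
    strides = []
    running = 1
    for dim in reversed(shape):  # walk from the innermost (rightmost) dim outward
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


print(row_major_strides((2, 8)))           # (8, 1)
print(row_major_strides((3, 5, 7, 9, 2)))  # (630, 126, 18, 2, 1)
```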
&lt;h3 id="column-major-one-bigger-step-into-the-pit">Column-Major: One Bigger Step into the Pit&lt;/h3>
&lt;p>Column-major conventions are far less common (e.g. Fortran, MATLAB) and store the elements of each column contiguously in memory. Before things become more confusing, let&amp;rsquo;s compare PyTorch and MATLAB using a 2x4 example tensor:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># python&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">([[&lt;/span>&lt;span class="mi">9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">]])&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-matlab" data-lang="matlab">&lt;span class="line">&lt;span class="cl">&lt;span class="c">% MATLAB&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span> &lt;span class="p">=&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mi">9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>When you create &lt;code>A&lt;/code> in torch, it allocates an 8-int memory chunk that stores the values in the order we gave them: &lt;code>[9,2,4,6,-1,3,7,0]&lt;/code>. On the other hand, MATLAB allocates the same 8-int memory chunk but stores the values column by column instead: &lt;code>[9,-1,2,3,4,7,6,0]&lt;/code>. When we index &lt;code>(i, j)&lt;/code> into torch and MATLAB (bless MATLAB&amp;rsquo;s soul for being 1-indexed), we obtain the same value:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-matlab" data-lang="matlab">&lt;span class="line">&lt;span class="cl">&lt;span class="c">% some MATLAB version that&amp;#39;s 0-indexed, technically A(2, 3)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">=&lt;/span> &lt;span class="mi">7&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">A&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">7&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>But now we realize the strides &lt;em>cannot&lt;/em> be the same. We know our row-major stride must be &lt;code>Stride&amp;lt;4, 1&amp;gt;&lt;/code>. Following our formula:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># i*stride_row + j*stride_col&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">offset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">1&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">4&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">6&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># torch: 6th index of [9,2,4,6,-1,3,7,0] is 7&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># matlab: 6th index of [9,-1,2,3,4,7,6,0] is 6&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The strides cannot be the same because the underlying memory is not the same. In torch, values in each &lt;em>row&lt;/em> are adjacent (9 is next to 2), but in MATLAB, values in each &lt;em>column&lt;/em> are adjacent in memory (9 is next to -1). So in our column-major view, to get to the next row we just step along the column, which we established is contiguous. Therefore, the stride of 1 now belongs to the &lt;em>leftmost&lt;/em> index. To get to the next column, we skip over one full column, so we step by the column length, which is the number of rows. Let&amp;rsquo;s redo our example from &lt;a class="link" href="#row-major-the-first-layer-of-hell" >above&lt;/a>:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;span class="lnt">9
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">shape&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">row_major_stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// new innnermost stride is 1, full column is size _2
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">col_major_stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">//////////////////////////////////////////////
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">shape&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">row_major_stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">630&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">126&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">18&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">col_major_stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">15&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">105&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">945&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;blockquote>
&lt;p>&lt;strong>Tip&lt;/strong>: Telling row-major vs col-major apart is easy. Any matrix with leftmost stride 1 (e.g. &lt;code>(1, 8)&lt;/code>) is column-major and any matrix with rightmost stride 1 (e.g. &lt;code>(8, 1)&lt;/code>) is row-major.&lt;/p>
&lt;/blockquote>
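&lt;p>To double-check the stride formulas, here is a minimal Python sketch. The helpers &lt;code>row_major_strides&lt;/code>, &lt;code>col_major_strides&lt;/code>, and &lt;code>offset&lt;/code> are hypothetical names of my own, not CuTe functions. They implement &amp;ldquo;each stride is the product of the extents on the faster-moving side&amp;rdquo; and reproduce the numbers above, including the torch/MATLAB lookup of &lt;code>A(1, 2)&lt;/code> on the 2x4 example.&lt;/p>

```python
# Hypothetical helpers, not CuTe API: compact strides for a given shape.
def row_major_strides(shape):
    """Rightmost mode is contiguous (stride 1), like C and PyTorch."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


def col_major_strides(shape):
    """Leftmost mode is contiguous (stride 1), like MATLAB and CuTe's default."""
    strides = []
    step = 1
    for extent in shape:
        strides.append(step)
        step *= extent
    return tuple(strides)


def offset(coord, strides):
    """Memory offset = sum over modes of coordinate * stride."""
    return sum(c * s for c, s in zip(coord, strides))


print(row_major_strides((2, 8)))           # (8, 1)
print(col_major_strides((2, 8)))           # (1, 2)
print(row_major_strides((3, 5, 7, 9, 2)))  # (630, 126, 18, 2, 1)
print(col_major_strides((3, 5, 7, 9, 2)))  # (1, 3, 15, 105, 945)

# The 2x4 example: same matrix, two storage orders, same element A(1, 2).
torch_mem = [9, 2, 4, 6, -1, 3, 7, 0]    # row-major storage
matlab_mem = [9, -1, 2, 3, 4, 7, 6, 0]   # column-major storage
print(torch_mem[offset((1, 2), row_major_strides((2, 4)))])   # offset 6 -> 7
print(matlab_mem[offset((1, 2), col_major_strides((2, 4)))])  # offset 5 -> 7
```

&lt;p>Note that the element is the same, but the offset differs (6 vs. 5), which is exactly why the strides must differ.&lt;/p>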
&lt;h2 id="indexing-hell-in-what-context">Indexing Hell: In What Context?&lt;/h2>
&lt;p>When we toggle between row- and column-major, we expect the indices to have the same meaning. For some 2D matrix &lt;code>A[i][j]&lt;/code>, we want the ith index to refer to the ith row and the jth index to refer to the jth column. When we iterate through the row-major or col-major version of A, we should get the exact same numbers because &lt;code>A_row_major[i][j] == A_col_major[i][j]&lt;/code>.&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: Although they return the same numbers, depending on the order of iteration, one layout will be less efficient because it jumps from address to address instead of walking memory contiguously. For example, &lt;code>for i...for j&lt;/code> is great for row-major but potentially a cache disaster for col-major. This only applies to tensors stored in memory. In the register file, indexing is purely an abstraction.&lt;/p>
&lt;/blockquote>
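&lt;p>The note above can be made concrete with a small Python sketch (plain Python, not CuTe; the names are illustrative). It lists the memory offsets a &lt;code>for i...for j&lt;/code> loop touches on the 2x4 example in each storage order: the values come out identical, but only the row-major walk is contiguous.&lt;/p>

```python
# Same logical 2x4 matrix in two storage orders, visited with "for i ... for j".
M, N = 2, 4
row_strides = (N, 1)  # row-major: each row is contiguous
col_strides = (1, M)  # column-major: each column is contiguous

torch_mem = [9, 2, 4, 6, -1, 3, 7, 0]   # row-major storage
matlab_mem = [9, -1, 2, 3, 4, 7, 6, 0]  # column-major storage


def visit_offsets(strides):
    """Offsets touched, in order, by a row-by-row loop (i outer, j inner)."""
    return [i * strides[0] + j * strides[1] for i in range(M) for j in range(N)]


print(visit_offsets(row_strides))  # [0, 1, 2, 3, 4, 5, 6, 7]: unit steps
print(visit_offsets(col_strides))  # [0, 2, 4, 6, 1, 3, 5, 7]: jumps of M

row_vals = [torch_mem[o] for o in visit_offsets(row_strides)]
col_vals = [matlab_mem[o] for o in visit_offsets(col_strides)]
print(row_vals == col_vals)  # True: same numbers, different access pattern
```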
&lt;p>So what do I mean by &amp;ldquo;in what context?&amp;rdquo; So far, we&amp;rsquo;ve been aiming to create an equivalent representation of A via a row- or column-major format. But for CUDA kernels, there is no equivalence&amp;ndash;we are given some tensors in a predefined row- or column-major format. FA2 expects Q, K, V to be contiguous &lt;em>along&lt;/em> &lt;code>head_dim&lt;/code>, i.e. &lt;code>(seqlen, head_dim)&lt;/code>. This means they are &lt;em>row-major&lt;/em> with respect to each &lt;em>token&lt;/em>&amp;ndash;each token is one row in the matrix. In this case, there is only one valid way to load the data&amp;ndash;the row-major way, since the underlying data is &lt;em>already fixed&lt;/em>.&lt;/p>
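&lt;p>A minimal sketch of why only one interpretation is valid once the data is fixed, assuming a made-up toy Q with &lt;code>seqlen = 3&lt;/code> and &lt;code>head_dim = 4&lt;/code> whose values encode &lt;code>10 * token + d&lt;/code>. Reading token 1 with the row-major strides it was written with returns that token; reading the same buffer with column-major strides silently mixes tokens.&lt;/p>

```python
# Toy Q buffer (made-up sizes): seqlen=3 tokens, head_dim=4, written row-major,
# i.e. each token's head_dim values are contiguous. Value = 10 * token + d.
seqlen, head_dim = 3, 4
q_mem = [10 * t + d for t in range(seqlen) for d in range(head_dim)]


def read_token(mem, t, strides):
    """Gather Q[t, 0:head_dim] through a (seqlen, head_dim) view with the given strides."""
    return [mem[t * strides[0] + d * strides[1]] for d in range(head_dim)]


row_major = (head_dim, 1)  # matches how the buffer was written
col_major = (1, seqlen)    # what a stride-less (column-major) layout would assume

print(read_token(q_mem, 1, row_major))  # [10, 11, 12, 13]: the real token 1
print(read_token(q_mem, 1, col_major))  # [1, 10, 13, 22]: pieces of three tokens
```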
&lt;h3 id="interpreting-fixed-data">Interpreting Fixed Data&lt;/h3>
&lt;p>In our new view, we are reading some predefined data like we did in the &lt;a class="link" href="#tensors" >tensor section&lt;/a>, so now let&amp;rsquo;s make some sense of it. Let&amp;rsquo;s take a look at row-major vs. col-major indexing for the shape (2, 8).&lt;/p>
&lt;blockquote>
&lt;p>You can leverage the &lt;code>print_layout()&lt;/code> function from &lt;a class="link" href="#layouts-shapes-and-strides" >earlier&lt;/a> to view this in your shell.&lt;/p>
&lt;/blockquote>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/row_col_major.png"
width="2344"
height="1308"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/row_col_major_hu7363f3ddb81f2e4ee7ea8b7de3601ec8_208236_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/row_col_major_hu7363f3ddb81f2e4ee7ea8b7de3601ec8_208236_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/row_col_major_hu7363f3ddb81f2e4ee7ea8b7de3601ec8_208236_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/row_col_major.png 2344w"
loading="lazy"
alt="Row vs. col major indexing on a layout with shape (2, 8)."
class="gallery-image"
data-flex-grow="179"
data-flex-basis="430px"
>
&lt;/p>
&lt;p>In this graphic, we index into the tensors via &lt;code>(i, j)&lt;/code> labels, where the number inside the square refers to the &lt;em>original index in the underlying memory&lt;/em> (it is not the linear ordering). In the row-major view, element 1 is in the same row as element 0. In the column-major view, element 1 is in the same column as element 0. When it comes to indexing, it pays to keep contiguity, indexing, and the physical data as separate ideas, and to reason in terms of the strides and offsets of your expected memory layout: applying a row-major layout to column-major memory (and vice versa) no longer makes any physical sense. We&amp;rsquo;ll really have to grapple with this later.&lt;/p>
&lt;blockquote>
&lt;p>For example, if you read column-major memory using a row-major format, what does a &amp;ldquo;column&amp;rdquo; even mean? You might drive yourself mad trying to figure out what a &amp;ldquo;row&amp;rdquo; and &amp;ldquo;column&amp;rdquo; mean because they mean different things with respect to different memory views, tensors, and layouts.&lt;/p>
&lt;/blockquote>
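&lt;p>If you don&amp;rsquo;t have CuTe handy, a rough stand-in for the graphic takes a few lines of Python. &lt;code>show_layout&lt;/code> is a hypothetical helper that only mimics the spirit of &lt;code>print_layout()&lt;/code>, not its exact output format: cell &lt;code>(i, j)&lt;/code> shows the memory offset that coordinate maps to.&lt;/p>

```python
# Hypothetical stand-in for print_layout (not CuTe's implementation):
# each cell (i, j) shows the memory offset that coordinate maps to.
def show_layout(shape, strides):
    rows, cols = shape
    grid = [[i * strides[0] + j * strides[1] for j in range(cols)] for i in range(rows)]
    for row in grid:
        print(" ".join(f"{v:2d}" for v in row))
    return grid


row_major = show_layout((2, 8), (8, 1))
#  0  1  2  3  4  5  6  7
#  8  9 10 11 12 13 14 15
col_major = show_layout((2, 8), (1, 2))
#  0  2  4  6  8 10 12 14
#  1  3  5  7  9 11 13 15
```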
&lt;h3 id="cute-default-is-column-major">CuTe Default is Column-Major&lt;/h3>
&lt;p>By default, CuTe resolves to column-major layouts. This was a choice NVIDIA made for whatever reason, and you should just accept it. If you create a layout with a shape but no stride, it will default to the column-major stride. When dealing with your data layout, always specify the stride for clarity. CuTe provides two primitives that make this slightly easier: &lt;code>GenRowMajor{}&lt;/code> and &lt;code>GenColMajor{}&lt;/code>.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Stride(8, 1)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">layout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">GenRowMajor&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Stride(1, 4)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">layout1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">GenColMajor&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Even if you specify all your layouts properly, there are some internal workings where you might see column-major layouts pop up (e.g. tiling, Atoms, etc.). You should specify strides as much as possible to avoid confusing yourself when you start to get weird errors.&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: The column-major default has &lt;strong>nothing to do with your underlying data&lt;/strong>. It&amp;rsquo;s just a consistent indexing pattern CuTe chose.&lt;/p>
&lt;/blockquote>
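&lt;p>The &amp;ldquo;no stride means column-major&amp;rdquo; rule also extends to the nested shapes we&amp;rsquo;ll meet shortly. Here is a minimal Python sketch of that rule, assuming a hypothetical &lt;code>default_col_major&lt;/code> helper (not CuTe&amp;rsquo;s implementation): walk the modes left to right and hand each one the running product of the extents before it.&lt;/p>

```python
# Sketch of the default column-major stride rule for (possibly nested) shapes.
# Hypothetical helper, not CuTe's implementation.
def default_col_major(shape, start=1):
    """Return (strides, next_start); leftmost modes get the smallest strides."""
    if isinstance(shape, int):
        return start, start * shape
    strides = []
    step = start
    for sub in shape:
        s, step = default_col_major(sub, step)
        strides.append(s)
    return tuple(strides), step


print(default_col_major((4, 8))[0])       # (1, 4), same as GenColMajor above
print(default_col_major(((2, 4), 4))[0])  # ((1, 2), 8)
```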
&lt;p>Our &lt;code>(seqlen, head_dim)&lt;/code> Q/K/V layout (from &lt;a class="link" href="#design-choices" >Design Choices&lt;/a>) is row-major: each token is a row, and its &lt;code>head_dim&lt;/code> values are contiguous. This lines up with PyTorch/JAX defaults and with Ampere&amp;rsquo;s row-major-oriented tensor ops &amp;ndash; so even though CuTe defaults to col-major indexing, our data and our copies will all be row-major.&lt;/p>
&lt;h3 id="linear-indexing-colex-indexing">Linear Indexing: Colex Indexing&lt;/h3>
&lt;p>CuTe also lets us index a multidimensional layout with a single integer, essentially treating, say, a &lt;code>(2, 3)&lt;/code> layout as a flat 6-element array&amp;ndash;we refer to this as &lt;strong>linear indexing&lt;/strong>. A programming language like C lets you do something similar:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="kt">int&lt;/span> &lt;span class="n">arr&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">{&lt;/span>&lt;span class="mi">9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">{&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kt">int&lt;/span> &lt;span class="n">val&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">arr&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">];&lt;/span> &lt;span class="c1">// val = 0
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>However, the way C and CuTe handle this internally is very different. C just treats the index &lt;code>[4]&lt;/code> as a memory offset; using its row-major memory layout, the 4th index grabs the 4th offset in memory, which lands in the second row and second element. In this system, the value at &lt;code>(1, 0)&lt;/code> is &amp;ldquo;further along&amp;rdquo; than &lt;code>(0, 1)&lt;/code>&amp;ndash;i.e. as you move from left to right and top to bottom, you increase in the order of access. We call this &lt;em>lexicographical order&lt;/em>. In lexicographical ordering, the index &lt;code>arr[4]&lt;/code> on shape (2,3) maps to &lt;code>arr[1][1]&lt;/code>.&lt;/p>
&lt;blockquote>
&lt;p>The name comes from dictionary (lexicon) order: compare the first index, and only on a tie move on to the next one, the same way words are sorted letter by letter. For 2D indices, this is also the order we read English (left to right, then top to bottom), so you can think of it as simply row-major indexing. Since C just works with offsets rather than explicitly mapping index &lt;code>[4]&lt;/code> to &lt;code>[1][1]&lt;/code>, we typically refer to it as &lt;strong>flat indexing&lt;/strong>.&lt;/p>
&lt;/blockquote>
&lt;p>On the other hand, CuTe uses &lt;strong>colexicographical indexing&lt;/strong> (i.e. colex indexing), which is the opposite&amp;ndash;order increases first from top to bottom, and then left to right. In this transposed view, index &lt;code>(1, 0)&lt;/code> is adjacent to &lt;code>(0, 0)&lt;/code> and comes &amp;ldquo;before&amp;rdquo; index &lt;code>(0, 1)&lt;/code>. Just as lex order is essentially row-major order, colex order is essentially column-major order. It&amp;rsquo;s CuTe&amp;rsquo;s way of consistently enforcing 1D-&amp;gt;N-dimensional indexing across layout algebra.&lt;/p>
&lt;p>The difference is that CuTe is intentional about this order, unlike C. In C, the order is a side effect of the memory offset. In CuTe, the layout explicitly converts the 1D index into a 2D coordinate before applying the strides. For example, if we index tensor &lt;code>t(idx)&lt;/code> for a tensor of shape &lt;code>(M, N)&lt;/code>, the index split becomes:&lt;/p>
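&lt;p>Both orders generalize to any number of modes. A small Python sketch (the helper names &lt;code>split_colex&lt;/code> and &lt;code>split_lex&lt;/code> are mine, not CuTe&amp;rsquo;s) that splits a linear index into a coordinate either way: colex peels off the leftmost mode first, lex peels off the rightmost.&lt;/p>

```python
# Hypothetical helpers (not CuTe API): split a linear index into a coordinate.
def split_colex(idx, shape):
    """Colexicographic: the leftmost mode varies fastest (CuTe's convention)."""
    coord = []
    for extent in shape:
        coord.append(idx % extent)
        idx //= extent
    return tuple(coord)


def split_lex(idx, shape):
    """Lexicographic: the rightmost mode varies fastest (C/PyTorch convention)."""
    coord = []
    for extent in reversed(shape):
        coord.append(idx % extent)
        idx //= extent
    return tuple(reversed(coord))


shape = (2, 3)
print([split_colex(k, shape) for k in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
print([split_lex(k, shape) for k in range(6)])
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
print(split_colex(4, shape), split_lex(4, shape))  # (0, 2) (1, 1)
```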
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Example, shape (2, 3), idx 4&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># [[9, 2, -1],[3, 0, 6]]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># intentional colex indexing&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="n">M&lt;/span> &lt;span class="c1"># 4 % 2 = 0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">j&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">M&lt;/span> &lt;span class="c1"># 4 / 2 = 2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># a[0][2] = -1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># typical lex indexing&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># C just uses memory offset, which is functionally&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># the same for its row-major mem layout&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="n">N&lt;/span> &lt;span class="c1"># 4 % 3 = 1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">j&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">N&lt;/span> &lt;span class="c1"># 4 / 3 = 1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># a[1][1] = 0&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The default colex indexing goes hand-in-hand with the column-major layout standard. It&amp;rsquo;s for consistency and has no practical meaning outside of this. But it does mean that if you&amp;rsquo;re working with row-major layouts, &lt;em>you still have to use colex indexing if you&amp;rsquo;re indexing linearly into a multidimensional tensor&lt;/em>. Indexing it like you would in C or torch will simply produce wrong results, and you might sit there twiddling your thumbs wondering why your code isn&amp;rsquo;t working. Don&amp;rsquo;t worry, it happens to the best of us.&lt;/p>
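&lt;p>Here is a plain-Python emulation of the C++ example that follows, under the assumption that a linear index is first split colexicographically and then dotted with the strides (the helper names are hypothetical). It also shows one way to recover C/torch order: compute the &lt;code>(i, j)&lt;/code> coordinate yourself and re-linearize it as &lt;code>i + j * M&lt;/code>. In real CuTe code, the simpler fix is to index with the 2D coordinate directly, i.e. &lt;code>tensor(i, j)&lt;/code>.&lt;/p>

```python
# Plain-Python emulation of the (2, 8):(8, 1) tensor below (a sketch, not CuTe code).
M, N = 2, 8
strides = (8, 1)           # row-major layout
data = list(range(1, 17))  # 1..16


def at_linear(k):
    """CuTe-style linear access: colex split into (i, j) first, then apply strides."""
    i, j = k % M, k // M
    return data[i * strides[0] + j * strides[1]]


print([at_linear(k) for k in range(M * N)])
# [1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16]


def at_c_order(k):
    """Walk in C/torch order: lex split k ourselves, then re-linearize colex as i + j * M."""
    i, j = k // N, k % N
    return at_linear(i + j * M)


print([at_c_order(k) for k in range(M * N)])
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
```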
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// linear printing colex example: it&amp;#39;s weird
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kt">int&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">data&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">9&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">11&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">12&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">13&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">14&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">15&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">layout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">tensor&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">layout&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="o">++&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">printf&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s">&amp;#34;, &amp;#34;&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// output: 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="composedhierarchical-layouts">Composed/Hierarchical Layouts&lt;/h2>
&lt;p>You are most likely to run into colex indexing issues when dealing with &lt;strong>hierarchical layouts&lt;/strong>. CuTe lets us nest layouts for more granular layout interpretations that wouldn&amp;rsquo;t be possible with non-nested layouts. For example, we can easily reinterpret a flat tensor of shape &lt;code>(8, )&lt;/code> as &lt;code>(2, 4)&lt;/code> or &lt;code>(4, 2)&lt;/code> in CuTe, but let&amp;rsquo;s take this a step further:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// (_8, _4): (_1, _8)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">l1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// ((_2, _4), _4):((_1, _2), _8)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">l2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_4&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Toy example
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">a&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">...,&lt;/span> &lt;span class="mi">31&lt;/span>&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// [0, 1, 2, 3, ..., 7]
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// [8, 9, 10, 11, ..., 15]
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// [16, 17, 18, 19, ..., 23]
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// [24, 25, 26, 27, ..., 31]
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">t1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">l1&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">t2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">a&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">l2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kt">int&lt;/span> &lt;span class="n">v1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">t1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="c1">// 19
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">v2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">t2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="c1">// 19
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">v3&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">t2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_coord&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="c1">// 19
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>In &lt;code>l2&lt;/code>, we iterate through the inner (left) shape in groups of 2, column-major by default. In this case, the composed layout is purely decorative. We can still address a tensor with layout &lt;code>l1&lt;/code> or &lt;code>l2&lt;/code> with two coordinates &lt;code>(i, j)&lt;/code>, and CuTe maps the translation underneath. However, we are now also grouping the inner dimension as four groups of two. We&amp;rsquo;ll see nested layouts become a powerful tool during the &lt;a class="link" href="#fragment-reshape" >MMA tiling reshape&lt;/a>, where we cannot simply flatten a 3D tensor into a non-nested 2D tensor for our specific access pattern.&lt;/p>
&lt;h1 id="the-beginning-cute-copy-then-cry">The Beginning: CuTe, Copy, then Cry&lt;/h1>
&lt;p>Okay, with the CuTe vomit over, let&amp;rsquo;s get started with FA2!&lt;/p>
&lt;p>Much of the learning path from here is nonlinear &amp;ndash; some concepts will click immediately and some will only make sense in hindsight. You have to be comfortable accepting certain things early; the full understanding only kicks in once you&amp;rsquo;ve hit a mental snag. I&amp;rsquo;ll tackle some concepts up front and gloss over others, and those choices are intentional &amp;ndash; when you&amp;rsquo;re actually building something, you rarely have the full context from the start. I spent hours thinking I understood something, only to realize a week later that I didn&amp;rsquo;t, and sometimes it took five of these &amp;ldquo;I understand&amp;rdquo;/&amp;ldquo;I definitely don&amp;rsquo;t understand&amp;rdquo; cycles before everything fell into place. Let&amp;rsquo;s dive in before you have time to reconsider 😇.&lt;/p>
&lt;h2 id="a100-ampere-specs">A100 (Ampere) Specs&lt;/h2>
&lt;p>The entire point of FA2, and of GPU optimization in general, is to keep the compute units busy by overlapping computation with memory loads. Here are the memory and hardware specs of the A100 GPU (Ampere):&lt;/p>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th style="text-align:left">Storage Level&lt;/th>
&lt;th style="text-align:left">Latency (Clock Cycles)&lt;/th>
&lt;th style="text-align:left">Approx. Slowdown vs. Registers&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Registers&lt;/strong>&lt;/td>
&lt;td style="text-align:left">~1 cycle&lt;/td>
&lt;td style="text-align:left">—&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Shared Memory (SMEM) / L1&lt;/strong>&lt;/td>
&lt;td style="text-align:left">~20–30 cycles&lt;/td>
&lt;td style="text-align:left">~25x&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>L2 Cache&lt;/strong>&lt;/td>
&lt;td style="text-align:left">~200 cycles&lt;/td>
&lt;td style="text-align:left">~200x&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>HBM2e (Main Memory)&lt;/strong>&lt;/td>
&lt;td style="text-align:left">~400–600+ cycles&lt;/td>
&lt;td style="text-align:left">&lt;strong>~500x+&lt;/strong>&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th style="text-align:left">Feature&lt;/th>
&lt;th style="text-align:left">Specification&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Total SMs&lt;/strong>&lt;/td>
&lt;td style="text-align:left">108 (A100 SXM4/PCIe) / 128 (full GA100 die)&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>CUDA Cores per SM&lt;/strong>&lt;/td>
&lt;td style="text-align:left">64 (FP32)&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Max Threads per SM&lt;/strong>&lt;/td>
&lt;td style="text-align:left">2048&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Max Warps per SM&lt;/strong>&lt;/td>
&lt;td style="text-align:left">64&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Max Blocks per SM&lt;/strong>&lt;/td>
&lt;td style="text-align:left">32&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Registers per SM&lt;/strong>&lt;/td>
&lt;td style="text-align:left">65,536 (32-bit)&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Max Registers per Thread&lt;/strong>&lt;/td>
&lt;td style="text-align:left">255&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Max Shared Memory per SM&lt;/strong>&lt;/td>
&lt;td style="text-align:left">164 KB&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>Max Shared Memory per Block&lt;/strong>&lt;/td>
&lt;td style="text-align:left">163 KB&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>L1 Cache (Combined with SMEM)&lt;/strong>&lt;/td>
&lt;td style="text-align:left">192 KB total pool per SM&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">&lt;strong>L2 Cache&lt;/strong>&lt;/td>
&lt;td style="text-align:left">40 MB&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
&lt;h2 id="gmem-smem-async-copying">GMEM-&amp;gt;SMEM (Async Copying)&lt;/h2>
&lt;p>For the A100, a rule of thumb is to aim for approximately &lt;strong>150-200 FLOPs per byte loaded from HBM&lt;/strong>. The exact number depends on your kernel and GPU, but the universal theme is the same: overlap your loads/stores with your actual compute.&lt;/p>
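&lt;p>Where does a number like that come from? One common estimate is the roofline &lt;em>ridge point&lt;/em>. It is peak compute divided by memory bandwidth, i.e. the arithmetic intensity at which a kernel stops being memory-bound. The sketch below assumes NVIDIA&amp;rsquo;s published A100 figures:&lt;/p>
&lt;ul>
&lt;li>312 TFLOPS of dense FP16 tensor-core throughput.&lt;/li>
&lt;li>About 1,555 GB/s of HBM bandwidth on the 40 GB card, or 2,039 GB/s on the 80 GB SXM card.&lt;/li>
&lt;/ul>
&lt;p>Swap in your own GPU&amp;rsquo;s numbers.&lt;/p>

```cpp
#include <cassert>
#include <cmath>

// Roofline ridge point: FLOPs a kernel must do per byte read from HBM before
// it stops being memory-bound. Inputs: peak TFLOP/s and HBM bandwidth in GB/s.
double ridge_point_flops_per_byte(double peak_tflops, double hbm_gb_per_s) {
    assert(peak_tflops > 0.0 && hbm_gb_per_s > 0.0);
    const double flops_per_s = peak_tflops * 1e12;
    const double bytes_per_s = hbm_gb_per_s * 1e9;
    return flops_per_s / bytes_per_s;
}

// Assumed A100 spec-sheet figures (dense FP16 tensor cores, no sparsity).
constexpr double kA100PeakFp16Tflops = 312.0;
constexpr double kA100_40GB_HbmGBps = 1555.0;
constexpr double kA100_80GB_HbmGBps = 2039.0;
```

&lt;p>That works out to about 200 FLOPs/byte for the 40 GB card and about 153 for the 80 GB card, which lines up with the 150-200 range above. Real kernels rarely hit either peak, so treat this as a ballpark, not a target.&lt;/p>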
&lt;p>Starting with the Ampere architecture, NVIDIA GPUs support asynchronous copies from GMEM-&amp;gt;SMEM (&lt;code>cp.async&lt;/code>), which let us overlap our tile fetches with compute. Before Ampere, data headed for SMEM had to be staged through registers (a global load followed by a shared store), so a thread could sit waiting hundreds of cycles for its bytes to arrive. On Ampere, we can issue some loads and immediately begin doing useful work while the memory moves in the background.&lt;/p>
&lt;p>The async design pattern is quite simple:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/async_pipeline.png"
width="2308"
height="1284"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/async_pipeline_hu246a953f40b88517180cbe2b4a521786_239370_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/async_pipeline_hu246a953f40b88517180cbe2b4a521786_239370_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/async_pipeline_hu246a953f40b88517180cbe2b4a521786_239370_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/async_pipeline.png 2308w"
loading="lazy"
alt="Fetch next data, do stuff with current data, wait for new data, and repeat."
class="gallery-image"
data-flex-grow="179"
data-flex-basis="431px"
>
&lt;/p>
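&lt;p>Stripped of GPU details, the loop in the figure is plain double buffering. Here is a host-side analogy in C++, not device code. It makes these substitutions:&lt;/p>
&lt;ul>
&lt;li>&lt;code>std::async&lt;/code> stands in for the asynchronous GMEM-&amp;gt;SMEM copy.&lt;/li>
&lt;li>A &lt;code>std::vector&lt;/code> stands in for an SMEM tile.&lt;/li>
&lt;li>&lt;code>load_tile&lt;/code> and &lt;code>compute_tile&lt;/code> are hypothetical placeholders.&lt;/li>
&lt;/ul>
&lt;p>The important part is the ordering: the fetch for tile &lt;code>k+1&lt;/code> is issued before we compute on tile &lt;code>k&lt;/code>, so the two overlap.&lt;/p>

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <future>
#include <numeric>
#include <vector>

// Hypothetical stand-in for a slow GMEM->SMEM fetch: copy tile `tile` out of `src`.
std::vector<int> load_tile(const std::vector<int>& src, std::size_t tile, std::size_t tile_size) {
    const auto first = src.begin() + static_cast<std::ptrdiff_t>(tile * tile_size);
    return std::vector<int>(first, first + static_cast<std::ptrdiff_t>(tile_size));
}

// Hypothetical stand-in for the math done on a tile that is already resident.
long long compute_tile(const std::vector<int>& tile) {
    return std::accumulate(tile.begin(), tile.end(), 0LL);
}

// Double buffering: start fetching tile k+1, compute on tile k, wait for k+1, repeat.
long long pipelined_sum(const std::vector<int>& src, std::size_t tile_size) {
    assert(tile_size > 0 && src.size() % tile_size == 0);
    const std::size_t num_tiles = src.size() / tile_size;
    if (num_tiles == 0) return 0;

    std::vector<int> current = load_tile(src, 0, tile_size);  // prologue: first tile
    long long total = 0;
    for (std::size_t k = 0; k < num_tiles; ++k) {
        std::future<std::vector<int>> next;
        if (k + 1 < num_tiles) {
            // 1. Fetch next data in the background.
            next = std::async(std::launch::async, load_tile, std::cref(src), k + 1, tile_size);
        }
        // 2. Do stuff with current data while the fetch is in flight.
        total += compute_tile(current);
        // 3. Wait for the new data before moving on.
        if (next.valid()) current = next.get();
    }
    return total;
}
```

&lt;p>On the GPU, the same ordering is typically expressed with CuTe&amp;rsquo;s &lt;code>cp_async_fence()&lt;/code> and &lt;code>cp_async_wait&amp;lt;N&amp;gt;()&lt;/code> plus a &lt;code>__syncthreads()&lt;/code> rather than futures. The fetch/compute/wait rhythm is the same.&lt;/p>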
&lt;p>We&amp;rsquo;ll cover how we apply this pattern to Q, K, V later on. There are some small CuTe details to be aware of, but the overall idea is exactly the same.&lt;/p>
&lt;p>Although it might seem like we can just fire off async copies and forget about them, there are two important concepts that can crush our performance if we&amp;rsquo;re not careful:&lt;/p>
&lt;h3 id="vectorized-and-coalesced-loads">Vectorized and Coalesced Loads&lt;/h3>
&lt;p>These two concepts are &lt;strong>vectorized loads&lt;/strong> and &lt;strong>coalesced loads&lt;/strong>. They sound similar and are a common point of confusion, so let&amp;rsquo;s break them down:&lt;/p>
&lt;ul>
&lt;li>&lt;strong>Vectorized Loads&lt;/strong>: A &lt;em>thread&lt;/em> loading as much data as it can in one &lt;em>instruction&lt;/em>.
&lt;ul>
&lt;li>Since we&amp;rsquo;re working with fp16, we could naively load one 16-bit number at a time.&lt;/li>
&lt;li>However, NVIDIA GPUs support a 128-bit load instruction &lt;em>per thread&lt;/em>: &lt;code>LDG.E.128&lt;/code> for global memory, and its SMEM counterpart &lt;code>LDS.128&lt;/code>. One of these loads 8 fp16 numbers in one go.&lt;/li>
&lt;li>A 16-bit load and a 128-bit load cost roughly the same to issue. Loading 16 bits at a time therefore means issuing 8x as many instructions for the same data, which can slash our throughput by up to 8x.&lt;/li>
&lt;li>So whenever we can, we load 128 bits at a time and split them into 8 halfs (1 fp16 = 1 half).&lt;/li>
&lt;li>&lt;strong>Note&lt;/strong>: the memory system fetches data in fixed-size chunks (at least a 32-byte sector for GMEM) no matter which instruction you use. It essentially throws away the bytes you don&amp;rsquo;t ask for, so a 16-bit load still pays for the whole chunk around it.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>&lt;strong>Coalesced Loads&lt;/strong>: A &lt;em>group of threads&lt;/em> (a warp) having its combined requests served in as few memory &lt;em>transactions&lt;/em> as possible.
&lt;ul>
&lt;li>GPUs never fetch from HBM just one byte at a time. They fetch a whole 32-, 64-, or at most 128-byte chunk in one go; this is the &lt;strong>transaction size&lt;/strong>.&lt;/li>
&lt;li>When the warp&amp;rsquo;s addresses fall into one contiguous, aligned 128-byte chunk, the memory controller serves the entire block in a single transaction. That block also fills a whole 128-byte L2 cache line, so nothing fetched goes to waste.&lt;/li>
&lt;li>If all 32 threads in the warp instead fetch random chunks scattered across memory, the memory controller has to issue up to 32 separate transactions, immediately crushing your performance, hopes, and dreams.&lt;/li>
&lt;li>&lt;strong>Note&lt;/strong>: coalescing is about how the warp&amp;rsquo;s addresses map onto memory transactions. It is independent of which load instruction each thread uses, and simply asks whether the warp&amp;rsquo;s requests can be grouped into full 128-byte chunks.&lt;/li>
&lt;li>You might notice that 32 threads each doing a 128-bit &lt;em>instruction&lt;/em> load adds up to 512 bytes, four times the transaction size. We&amp;rsquo;ll cover how this works below.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
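&lt;p>To make coalescing concrete, here is a simplified model. It is an assumption, not how the hardware is specified. It counts how many aligned 128-byte segments a warp touches, given each lane&amp;rsquo;s starting byte address and per-thread load width. Real GPUs track 32-byte sectors inside those segments, so treat the counts as a rough guide. The helper names are hypothetical.&lt;/p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Transaction size assumed in the text (real hardware tracks 32-byte sectors inside it).
constexpr std::uint64_t kSegmentBytes = 128;

// Distinct aligned 128-byte segments touched when lane i reads `bytes_per_lane`
// contiguous bytes starting at lane_addrs[i].
std::size_t segments_touched(const std::vector<std::uint64_t>& lane_addrs,
                             std::uint64_t bytes_per_lane) {
    assert(bytes_per_lane > 0);
    std::set<std::uint64_t> segments;
    for (const std::uint64_t addr : lane_addrs) {
        const std::uint64_t first = addr / kSegmentBytes;
        const std::uint64_t last = (addr + bytes_per_lane - 1) / kSegmentBytes;
        for (std::uint64_t s = first; s <= last; ++s) segments.insert(s);
    }
    return segments.size();
}

// Byte addresses for a 32-lane warp where lane i starts at base + i * stride_bytes.
std::vector<std::uint64_t> strided_warp(std::uint64_t base, std::uint64_t stride_bytes) {
    std::vector<std::uint64_t> addrs(32);
    for (std::uint64_t lane = 0; lane < 32; ++lane) addrs[lane] = base + lane * stride_bytes;
    return addrs;
}
```

&lt;p>Here is what the model gives for a few access patterns:&lt;/p>
&lt;ul>
&lt;li>128-bit loads at a 16-byte lane stride put 512 bytes into exactly 4 segments, the ideal.&lt;/li>
&lt;li>fp16 loads at a 2-byte stride fit in 1 segment but only move 64 bytes.&lt;/li>
&lt;li>A 256-byte stride touches 32 segments, one per lane.&lt;/li>
&lt;li>Shifting the contiguous pattern 64 bytes off alignment costs a fifth segment.&lt;/li>
&lt;/ul>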
&lt;blockquote>
&lt;p>&lt;strong>Tip&lt;/strong>: Here&amp;rsquo;s how you can figure out which one fits your scenario:&lt;/p>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th style="text-align:left">Question / Scenario&lt;/th>
&lt;th style="text-align:left">Concept&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td style="text-align:left">Can each thread load 128 bits at a time with my data layout?&lt;/td>
&lt;td style="text-align:left">&lt;strong>Vectorization&lt;/strong>&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">Can my warp&amp;rsquo;s requests be served as full 128-byte HBM transactions?&lt;/td>
&lt;td style="text-align:left">&lt;strong>Coalescing&lt;/strong>&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
&lt;/blockquote>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/vec.png"
width="1814"
height="1274"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/vec_hu330d518e49285b7068a4fe6df27d5816_171152_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/vec_hu330d518e49285b7068a4fe6df27d5816_171152_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/vec.png 1814w"
loading="lazy"
alt="Vectorized load example. Can issue 4 fp32 load instructions or just 1 128-bit load and reinterpret as fp32. Byte-addressed, so 0x4 address increment per float."
class="gallery-image"
data-flex-grow="142"
data-flex-basis="341px"
>
&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/coalesce.png"
width="1788"
height="984"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/coalesce_hu841b47aa88dd1af49e00c1cc3f4212d4_342432_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/coalesce_hu841b47aa88dd1af49e00c1cc3f4212d4_342432_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/coalesce.png 1788w"
loading="lazy"
alt="Coalesced load example. Four threads each want 128 bits, making a 64-byte contiguous chunk. The memory controller combines it into one transaction."
class="gallery-image"
data-flex-grow="181"
data-flex-basis="436px"
>
&lt;/p>
&lt;p>Both vectorized and coalesced loads expect the data to be contiguous and aligned (128 bits and 128 bytes, respectively). If your data is scattered, you might not get the full benefit of vectorization and coalescing. However, loading 64 bits per thread or 64 bytes per transaction could still be good enough for your purpose. If memory becomes a bottleneck, you can always consider reformatting the data or loading it out of order, as long as your downstream compute handles the data correctly.&lt;/p>
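&lt;p>One rough way to pick a fallback width is to take the widest power-of-two access, up to 16 bytes, that divides both the starting byte address (alignment) and the length of the contiguous run. The helper below is a hypothetical illustration of that rule, not a CuTe API. It assumes one uniform width for the whole run.&lt;/p>

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: widest per-thread access in bytes (capped at 16 = 128 bits)
// usable for a contiguous run of `run_bytes` starting at byte address `addr`.
// The width must divide the address (alignment) and the run length (no over-read).
std::uint32_t widest_vector_bytes(std::uint64_t addr, std::uint64_t run_bytes) {
    assert(run_bytes > 0);
    for (std::uint32_t width = 16; width > 1; width /= 2) {
        if (addr % width == 0 && run_bytes % width == 0) return width;
    }
    return 1;  // fall back to byte-sized accesses
}
```

&lt;p>A 16-byte-aligned, 128-byte run gets full 128-bit loads. An 8-byte-aligned run drops to 64-bit loads, and an odd address falls all the way back to single bytes.&lt;/p>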
&lt;blockquote>
&lt;p>&lt;strong>Note: Memory coalescing only applies to GMEM/HBM&lt;/strong>, while vectorization applies to both GMEM and SMEM, although in slightly different ways. In both memory spaces, vectorization cuts the number of load/store instructions we issue. That reduces instruction pressure and puts more bytes in flight per instruction. We&amp;rsquo;ll cover more details about bank conflicts and swizzling in our &lt;a class="link" href="#smem-registers" >SMEM-&amp;gt;register section&lt;/a> later.&lt;/p>
&lt;/blockquote>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: Don&amp;rsquo;t confuse &lt;em>vectorized loads&lt;/em> with &lt;em>compute vectorization&lt;/em>. Although they share a name, vectorized loads are about memory throughput, while vectorized compute is about parallel arithmetic. For example, numpy&amp;rsquo;s compute vectorization leverages SIMD CPU instructions, where a single instruction adds several array elements at once. A GPU thread&amp;rsquo;s vectorized load instead bundles several values into one load instruction to move more bytes per instruction. Similar concept, different meanings depending on context.&lt;/p>
&lt;/blockquote>
&lt;h3 id="copy-atoms">Copy Atoms&lt;/h3>
&lt;p>Every NVIDIA GPU has a boatload of copy instructions&amp;ndash;you can move 1, 2, 4, 8, or 16 bytes per thread, synchronously or (on Ampere and newer) asynchronously. CuTe neatly packages these copy instructions into a core piece called an &lt;code>Atom&lt;/code>. These pieces are &amp;ldquo;atomic&amp;rdquo; in the sense of indivisible, not atomic operations. Each one wraps a hardware instruction, and you eventually pass it to the &lt;code>copy&lt;/code> function so it knows which instruction to use to copy your data.&lt;/p>
&lt;p>Ampere has dedicated asynchronous copy operations, prefixed with the architecture name &lt;code>SM80&lt;/code>: &lt;code>SM80_CP_ASYNC_CACHEGLOBAL&amp;lt;T&amp;gt;&lt;/code> and &lt;code>SM80_CP_ASYNC_CACHEALWAYS&amp;lt;T&amp;gt;&lt;/code>. Here &lt;code>T&lt;/code> is the type each thread copies (e.g. &lt;code>cute::uint128_t&lt;/code>), and we wrap one of these operations in a &lt;code>Copy_Atom&lt;/code>. The &lt;code>cache_global&lt;/code> and &lt;code>cache_always&lt;/code> variants map to the PTX&lt;sup id="fnref1:3">&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref">3&lt;/a>&lt;/sup> instructions &lt;code>cp.async.cg.shared.global&lt;/code> and &lt;code>cp.async.ca.shared.global&lt;/code>:&lt;/p>
&lt;ul>
&lt;li>&lt;code>cache_global&lt;/code> caches the data only in L2 and skips L1 on its way to SMEM.&lt;/li>
&lt;li>&lt;code>cache_always&lt;/code> also caches it in L1.&lt;/li>
&lt;/ul>
&lt;p>&lt;code>cache_always&lt;/code> is the safer default when threads re-read nearby data and benefit from L1&amp;rsquo;s spatial and temporal locality. But in FA2, a thread block never re-reads a Q, K, or V tile once it&amp;rsquo;s loaded into SMEM&amp;ndash;therefore, we can bypass the L1 cache, which is slightly faster. Bypassing L1 also reduces thrashing there and allows more important data to stay in cache. In practice, this is a micro-optimization and not that important.&lt;/p>
&lt;p>The copy type &lt;code>T&lt;/code> can be at most 128 bits wide per thread, and the &lt;code>cache_global&lt;/code> variant only supports that full 128-bit (16-byte) size. The unit is &lt;strong>bits&lt;/strong>, not bytes, since these atoms are viewed through the &lt;strong>thread perspective&lt;/strong>. Across a warp, our atom therefore loads a total of $128 \text{ bits} \cdot 32 / 8 = 512 \text{ bytes}$. This means each 128-bit fetch across the 32 threads in a warp takes $512/128 = 4$ memory transactions in four &amp;ldquo;phases&amp;rdquo; (more on this later). For our purposes, we want the full coalesced 128-bit width with &lt;code>cache_global&lt;/code>. We can define the &lt;code>Copy_Atom&lt;/code> with the following code:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#include&lt;/span> &lt;span class="cpf">&amp;lt;cute/atom/copy_atom.hpp&amp;gt;&lt;/span>&lt;span class="cp">
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="k">using&lt;/span> &lt;span class="n">GmemCopyAtom&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">SM80_CP_ASYNC_CACHEGLOBAL&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">uint128_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>We use the cute namespace types for robustness, and our source data type is fp16 (&lt;code>cute::half_t&lt;/code>). Each thread therefore loads $128/16=8$ halfs.&lt;/p>
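&lt;p>The sizes in the last two paragraphs all come from a few divisions, so here they are as compile-time constants. The 128-byte transaction size is the one assumed in the text. The constant names are hypothetical.&lt;/p>

```cpp
#include <cassert>

constexpr int kCopyBits = 128;          // per-thread width of SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>
constexpr int kHalfBits = 16;           // one fp16 / cute::half_t
constexpr int kWarpSize = 32;
constexpr int kTransactionBytes = 128;  // per-transaction size assumed in the text

constexpr int kHalfsPerThread = kCopyBits / kHalfBits;                          // 8 halfs
constexpr int kBytesPerWarpCopy = kCopyBits * kWarpSize / 8;                    // 512 bytes
constexpr int kWordsPerWarpCopy = kBytesPerWarpCopy / 4;                        // 128 32-bit words
constexpr int kTransactionsPerWarpCopy = kBytesPerWarpCopy / kTransactionBytes; // 4 transactions

static_assert(kCopyBits % kHalfBits == 0, "a 128-bit copy holds a whole number of halfs");
static_assert(kBytesPerWarpCopy % kTransactionBytes == 0, "a warp copy splits evenly into transactions");
```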
&lt;h4 id="how-do-32-threads-load-128-bits-each">How do 32 threads load 128 bits each?&lt;/h4>
&lt;p>Each warp&amp;rsquo;s 32 threads load 32 128-bit chunks in tandem, which is 512 total bytes, or 128 words&lt;sup id="fnref:4">&lt;a href="#fn:4" class="footnote-ref" role="doc-noteref">4&lt;/a>&lt;/sup>, or 4x32 bank accesses (see the &lt;a class="link" href="#bank-conflicts-and-smem-layout" >bank conflicts section&lt;/a> below). The GPU cannot physically load 512 bytes in one go &amp;ndash; both GMEM and SMEM can only move 128 bytes at a time&lt;sup id="fnref:5">&lt;a href="#fn:5" class="footnote-ref" role="doc-noteref">5&lt;/a>&lt;/sup>:&lt;/p>
&lt;ul>
&lt;li>GMEM can load &lt;strong>128 contiguous bytes&lt;/strong> in one memory transaction.&lt;/li>
&lt;li>SMEM can load any &lt;strong>128 bytes as long as there are no bank conflicts&lt;/strong>.&lt;/li>
&lt;li>Both assume the memory addresses are aligned to your data width (e.g. 128-bit load -&amp;gt; 16-byte aligned address).&lt;/li>
&lt;/ul>
&lt;p>As a result, GMEM/SMEM handles these 512-byte loads as four separate 128-byte memory transactions. Each &lt;em>transaction phase&lt;/em> serves the data for 8 threads, also called a &lt;em>quarter-warp&lt;/em>. The memory controller is smart enough to group the relevant addresses together so that each transaction uses the full 128-byte bandwidth when possible (see &lt;a class="link" href="#vectorized-and-coalesced-loads" >vectorization/coalescing&lt;/a>).&lt;/p>
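&lt;p>After each phase, one quarter-warp has received its 8x128-bit (128-byte) block. As long as each quarter-warp&amp;rsquo;s eight 16-byte chunks form one contiguous, aligned 128-byte block, every transaction carries 128 bytes of useful data, so none of the fetched bandwidth is wasted.&lt;/p>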
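&lt;p>To see how the phase split depends on the addresses, here is a simplified host-side model. It is an assumption, not a hardware specification, and the helper names are hypothetical. It splits a warp&amp;rsquo;s 16-byte loads into the four 8-lane phases and counts how many aligned 128-byte blocks each phase touches. A count of 1 means the phase is one full transaction.&lt;/p>

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Byte address of each lane's 16-byte load: lane i starts at base + i * stride_bytes.
std::vector<std::uint64_t> lane_addresses(std::uint64_t base, std::uint64_t stride_bytes) {
    std::vector<std::uint64_t> addrs(32);
    for (std::size_t lane = 0; lane < addrs.size(); ++lane) addrs[lane] = base + lane * stride_bytes;
    return addrs;
}

// For each quarter-warp phase (lanes 0-7, 8-15, 16-23, 24-31), count the aligned
// 128-byte blocks its 16-byte loads touch.
std::vector<int> blocks_per_phase(const std::vector<std::uint64_t>& lane_addrs) {
    constexpr std::uint64_t kBytesPerLane = 16;
    constexpr std::uint64_t kBlockBytes = 128;
    constexpr std::size_t kLanesPerPhase = 8;
    assert(lane_addrs.size() == 32);
    std::vector<int> counts;
    for (std::size_t phase = 0; phase < 4; ++phase) {
        std::set<std::uint64_t> blocks;
        for (std::size_t lane = phase * kLanesPerPhase; lane < (phase + 1) * kLanesPerPhase; ++lane) {
            const std::uint64_t a = lane_addrs[lane];
            for (std::uint64_t b = a / kBlockBytes; b <= (a + kBytesPerLane - 1) / kBlockBytes; ++b)
                blocks.insert(b);
        }
        counts.push_back(static_cast<int>(blocks.size()));
    }
    return counts;
}
```

&lt;p>Contiguous, aligned 128-bit loads give one block per phase. Shifting the same pattern by 64 bytes doubles that to two blocks per phase. A 256-byte lane stride makes every lane in a phase hit its own block.&lt;/p>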
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: this behavior is handled at the &lt;strong>hardware level&lt;/strong>. CuTe doesn&amp;rsquo;t do some 4-iteration loop or anything &amp;ndash; it&amp;rsquo;s the memory controller&amp;rsquo;s job.&lt;/p>
&lt;/blockquote>
&lt;h3 id="tiled-copy">Tiled Copy&lt;/h3>
&lt;p>Each thread copies 128 bits at a time, but a thread block can have any number of threads/warps, depending on how the kernel is configured. Each A100 SM has 4 tensor cores (one per SM sub-partition), so 4 warps (128 threads) per block is a common choice for FA2. This means we have to decide how those threads split up the copy of each Q, K, V tile into 128-bit async copies.&lt;/p>
&lt;p>CuTe uses a &lt;code>TiledCopy&lt;/code> (built with &lt;code>make_tiled_copy&lt;/code>), which &amp;ldquo;tiles&amp;rdquo; a copy pattern in a structured way over the entire memory region you are copying (in this case, from GMEM). It outlines the &amp;ldquo;tiling strategy&amp;rdquo; that your threads will follow.&lt;/p>
&lt;blockquote>
&lt;p>Note that the &amp;ldquo;tiling&amp;rdquo; here is not the same as the Q, K, V tile. It&amp;rsquo;s tiling the memory layout, while our Q, K, V are tiles of our algorithm. Unfortunately in our case, it&amp;rsquo;s tiling&amp;hellip;our tiles.&lt;/p>
&lt;/blockquote>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// layouts are not filled in yet
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">using&lt;/span> &lt;span class="n">MyTiledCopy&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_tiled_copy&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Atom&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">T&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="c1">// Atom
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="c1">// Thread layout (who)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{}&lt;/span> &lt;span class="c1">// Value layout (what per thread)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The tiled copy function &lt;code>make_tiled_copy&lt;/code> takes in three things: the atom, the thread layout, and the value layout (the values each thread handles).&lt;/p>
&lt;ul>
&lt;li>&lt;strong>Value layout&lt;/strong>: Our &lt;code>Copy_Atom&lt;/code> moves one 128-bit chunk, i.e. 8 fp16 values per thread. Given our row-major inputs, those 8 values must run along a row, so the value layout is &lt;code>Layout&amp;lt;Shape&amp;lt;_1, _8&amp;gt;&amp;gt;{}&lt;/code>.&lt;/li>
&lt;li>&lt;strong>Thread layout&lt;/strong>: how the threads are arranged over the tile. Assuming &lt;code>kNThreads=128&lt;/code>, each of the 128 threads gets one 128-bit chunk per tile. The layout&amp;rsquo;s shape says how many threads sit along each dimension, and its stride says how thread indices are numbered across that grid.&lt;/li>
&lt;/ul>
&lt;p>The easiest strategy is to spread consecutive threads across the columns first and then down the rows, essentially filling the tile from the top (see image below).&lt;/p>
&lt;blockquote>
&lt;p>This gets slightly tricky here because of bank conflict optimization. Dao uses the same tiled copy setup for Q, K, V despite them having slightly different dimensions. We&amp;rsquo;ll revisit this when we talk about bank conflicts, but for now, assume our SMEM block is of shape &lt;code>(_, kBlockKSmem)&lt;/code>, where &lt;code>kBlockKSmem&lt;/code> is the column width for all 3 tensors. We can compute the layout as:&lt;/p>
&lt;/blockquote>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// assumes kBlockKSmem and kNThreads are already static constexpr ints
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">halfs_per_128bit_load&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">sizeof&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">uint128_t&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="k">sizeof&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">threads_per_row&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">kBlockKSmem&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">halfs_per_128bit_load&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">num_thread_rows&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">kNThreads&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">threads_per_row&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">num_thread_cols&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">threads_per_row&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>For &lt;code>kBlockKSmem=64&lt;/code>, each row is 64 halfs or 8 128-bit loads, so 8 threads per row. With 128 threads, we cover $128/8=16$ rows per tile. The stride maps a thread&amp;rsquo;s (row, column) position in this grid to its thread index. Moving one column over should bump the index by the static &lt;code>_1{}&lt;/code>, and moving one row down should bump it by a full row of &lt;code>num_thread_cols&lt;/code> threads. Hence, our &lt;code>TiledCopy&lt;/code> is:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// constexpr ints must be wrapped as static Int&amp;lt;&amp;gt; types inside Shape/Stride
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">using&lt;/span> &lt;span class="n">TiledCopyQKV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_tiled_copy&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">GmemCopyAtom&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">num_thread_rows&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">num_thread_cols&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">num_thread_cols&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/gmem_tiling.png"
width="1828"
height="1186"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/gmem_tiling_hu8471f6b995ae5327f227a2a9278bf604_124679_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/gmem_tiling_hu8471f6b995ae5327f227a2a9278bf604_124679_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/gmem_tiling.png 1828w"
loading="lazy"
alt="The tiling strategy above, but with two warps instead of four, to show how warps cycle. Each row is 8 threads, and each square is one thread&amp;rsquo;s 128-bit load (8 fp16s). Each full tile is two warp-loads of four rows each, so 8 rows per tile."
class="gallery-image"
data-flex-grow="154"
data-flex-basis="369px"
>
&lt;/p>
&lt;p>The way to think about this is that the &lt;code>TiledCopy&lt;/code> is the tiling strategy for your source memory (GMEM in this case). One pass of all 128 threads covers 128 chunks of 128 bits (a 16x64 block). For a taller tile, the same pattern repeats down the rows until the entire GMEM tile is covered, so each thread ends up owning several 128-bit chunks, one per repetition. Even though this example uses a GMEM source, a &lt;code>TiledCopy&lt;/code> works between GMEM, SMEM, and per-thread registers. It doesn&amp;rsquo;t know what anything is&amp;ndash;it&amp;rsquo;s just the floorplan, and we&amp;rsquo;re responsible for providing the expected input.&lt;/p>
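&lt;p>To sanity-check the floorplan, here is a host-side sketch in plain C++, not CuTe. It reproduces the mapping the layouts above describe, assuming &lt;code>kBlockKSmem = 64&lt;/code> and &lt;code>kNThreads = 128&lt;/code>:&lt;/p>
&lt;ul>
&lt;li>The thread layout &lt;code>(16, 8):(8, 1)&lt;/code> places thread &lt;code>t&lt;/code> at grid row &lt;code>t / 8&lt;/code> and grid column &lt;code>t % 8&lt;/code>.&lt;/li>
&lt;li>The value layout &lt;code>(1, 8)&lt;/code> gives that thread 8 consecutive halfs.&lt;/li>
&lt;li>The 16-row pattern repeats down a taller tile.&lt;/li>
&lt;/ul>
&lt;p>All helper names here are hypothetical.&lt;/p>

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kBlockKSmem = 64;                           // tile width in halfs
constexpr int kNThreads = 128;                            // 4 warps
constexpr int kHalfsPerLoad = 8;                          // 128-bit copy / 16-bit half
constexpr int kThreadCols = kBlockKSmem / kHalfsPerLoad;  // num_thread_cols = 8
constexpr int kThreadRows = kNThreads / kThreadCols;      // num_thread_rows = 16

struct Chunk {
    int row;  // tile row this 128-bit chunk comes from
    int col;  // first column (in halfs) of the chunk
};

// Invert the thread layout (kThreadRows, kThreadCols):(kThreadCols, 1), which maps
// (row, col) -> tid = row * kThreadCols + col, then scale col by the (1, 8) value layout.
Chunk first_chunk(int tid) {
    assert(tid >= 0 && tid < kNThreads);
    return {tid / kThreadCols, (tid % kThreadCols) * kHalfsPerLoad};
}

// Every chunk thread `tid` copies from a tile_rows x kBlockKSmem tile:
// the 16-row pattern repeats down the tile, one chunk per repetition.
std::vector<Chunk> chunks_for_thread(int tid, int tile_rows) {
    assert(tile_rows % kThreadRows == 0);
    const Chunk base = first_chunk(tid);
    std::vector<Chunk> chunks;
    for (int rep = 0; rep < tile_rows / kThreadRows; ++rep) {
        chunks.push_back({base.row + rep * kThreadRows, base.col});
    }
    return chunks;
}

// Element offset of a chunk in a row-major tile whose rows are `row_stride` halfs apart.
int element_offset(Chunk c, int row_stride) { return c.row * row_stride + c.col; }

// True if every element of a tile_rows x kBlockKSmem tile is copied by exactly one thread.
bool covers_tile_exactly(int tile_rows) {
    std::vector<int> hits(static_cast<std::size_t>(tile_rows * kBlockKSmem), 0);
    for (int tid = 0; tid < kNThreads; ++tid) {
        for (const Chunk& c : chunks_for_thread(tid, tile_rows)) {
            for (int i = 0; i < kHalfsPerLoad; ++i) ++hits[element_offset(c, kBlockKSmem) + i];
        }
    }
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

&lt;p>For a 128-row tile, each thread owns 8 chunks. For example, thread 9 copies columns 8-15 of rows 1, 17, 33, and so on up to 113. Passing a different &lt;code>row_stride&lt;/code> to &lt;code>element_offset&lt;/code>, say the row stride of Q in GMEM, changes only the physical offset. It does not change which logical (row, column) a thread owns.&lt;/p>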
&lt;h3 id="tiled-copy-source-and-destination">Tiled Copy, Source and Destination&lt;/h3>
&lt;p>Our &lt;code>TiledCopy&lt;/code> determines how our source is tiled, but we still have to configure the destination. The destination layout comes from the destination tensor itself. The &lt;code>TiledCopy&lt;/code> simply writes each thread&amp;rsquo;s data to the same logical (row, column) coordinates in the destination that it read from in the source. The destination layout can be essentially anything, as long as it is compatible with the &lt;code>Copy_Atom&lt;/code>. Since we have 128-bit loads/stores, the destination tensor layout must accept aligned, contiguous 8-half blocks (more on this in swizzling). For now, we can ignore what the output tensor is. A &lt;code>TiledCopy&lt;/code> follows a specific pattern for copying between a source and a destination: a thread view, a partitioning step, and finally the copy.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// defining tiled copy
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Traits&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">GmemTiledCopyQKV&lt;/span> &lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// what thread are we? let&amp;#39;s get the slice of the data
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// that belongs to thread tid
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">gmem_thr_copy_QKV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// partition thread Q gmem SOURCE tensor
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tQgQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_thr_copy_QKV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// partition thread Q smem DEST tensor
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tQsQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_thr_copy_QKV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// copy op: (tiled_copy, source, dest)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tQgQ&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tQsQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>In this example, assume &lt;code>gQ&lt;/code> and &lt;code>sQ&lt;/code> are correctly defined GMEM and SMEM tensors. We first define our tiled copy blueprint. Then we get the thread slice of this tiled copy, which translates the block-wide tiled copy object into the values this thread actually fetches. Then we partition the source and the destination, laying the thread blueprint over both tensors. Finally, we issue the copy operation.&lt;/p>
&lt;blockquote>
&lt;p>Example: with 128 threads each moving one 128-bit chunk per pass, thread 0 takes the 0th (first) 128 bits, halves 0-7. One pass of all 128 threads covers 128 consecutive chunks, so thread 0 next takes the 128th 128-bit chunk, then the 256th, the 384th, and so on until the source is tiled. The partitioned thread tensor has shape &lt;code>((1, 8), M, N)&lt;/code>, where &lt;code>(1, 8)&lt;/code> is the value layout of a single copy and M, N count how many times that copy repeats across the tile. The exact shape CuTe prints may differ, but it rarely matters because we seldom work with the intermediate partition directly.&lt;/p>
&lt;/blockquote>
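&lt;p>To make the example concrete, here is a small host-side sketch of the index math. It is plain C++17 rather than CuTe. It assumes 128 threads, &lt;code>kHeadDim = 64&lt;/code> (eight 128-bit chunks per row), &lt;code>kBlockM = 128&lt;/code>, and a row-major 16x8 thread grid. The names &lt;code>chunk_index&lt;/code> and &lt;code>chunk_origin&lt;/code> are made up for illustration.&lt;/p>

```cpp
// Host-side sketch of the GMEM->SMEM tiled-copy index math (plain C++17,
// not CuTe API). Assumptions: 128 threads, fp16 data, 128-bit (8-half)
// vectors, kHeadDim = 64 so each row holds 8 chunks, kBlockM = 128, and
// a row-major 16 x 8 thread grid over the (kBlockM x kHeadDim) tile.
constexpr int kThreads = 128;
constexpr int kHalvesPerChunk = 8;                         // 128 bits / 16 bits
constexpr int kHeadDim = 64;
constexpr int kChunksPerRow = kHeadDim / kHalvesPerChunk;  // 8
constexpr int kBlockM = 128;

struct Origin {
  int row;  // tile row the chunk starts in
  int col;  // first column (half) of the 8-half vector
};

// Linear index of the i-th 128-bit chunk copied by thread tid. One pass of
// all threads covers kThreads consecutive chunks, so each thread strides by
// kThreads chunks from one copy to the next.
constexpr int chunk_index(int tid, int i) { return tid + i * kThreads; }

// Position of that chunk inside the tile.
constexpr Origin chunk_origin(int tid, int i) {
  return {chunk_index(tid, i) / kChunksPerRow,
          (chunk_index(tid, i) % kChunksPerRow) * kHalvesPerChunk};
}

// Copies each thread issues to cover the whole tile.
constexpr int kCopiesPerThread = kBlockM * kChunksPerRow / kThreads;  // 8

static_assert(chunk_index(0, 1) == 128);      // thread 0: chunk 0, then 128
static_assert(chunk_origin(0, 1).row == 16);  // i.e. 16 rows further down
```

&lt;p>With these numbers, one pass of the thread grid covers 16 rows. Each thread therefore issues 8 copies down the tile and 1 across it, which is what the M and N modes of the partitioned tensor would count under these assumptions.&lt;/p>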
&lt;h3 id="gmem-and-smem-tensors">GMEM and SMEM Tensors&lt;/h3>
&lt;p>Saved the easiest step for last. Let&amp;rsquo;s define the &lt;code>gX&lt;/code> and &lt;code>sX&lt;/code> tensors for GMEM and SMEM.&lt;/p>
&lt;p>CuTe provides a convenient API to retrieve the proper tensor tile from the source. It has the unfortunate side effect of being somewhat convoluted and ugly, but hey, it works.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Assume we have a params struct that contains our source parameters
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// like pointers, dims, and strides
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// gmem
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">mQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_gmem_ptr&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">reinterpret_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">const&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span> &lt;span class="o">*&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">q_ptr&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">batch_idx&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">q_batch_stride&lt;/span> &lt;span class="o">+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">head_idx&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">q_head_stride&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">seqlen_q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">head_dim&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">q_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">gQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">local_tile&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">mQ&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_coord&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m_block&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// smem
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">reinterpret_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span> &lt;span class="o">*&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">SmemLayoutQ&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sK&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">SmemLayoutKV&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This looks awful, but the mechanism is simple. Each thread block operates on a unique block of Q for one (batch, head) pair. We compute the batch and head indices and offset the Q pointer by the batch and head strides, arriving at that batch/head&amp;rsquo;s Q matrix of shape &lt;code>(seqlen_q, head_dim)&lt;/code>. CuTe primitives like &lt;code>make_gmem_ptr&lt;/code> and &lt;code>make_smem_ptr&lt;/code> tag a pointer with its memory space, so the underlying engine can issue the correct PTX instructions for moving data between GMEM, SMEM, and the register file. Because &lt;code>mQ&lt;/code> carries a shape and stride, we can call &lt;code>local_tile(tensor, tile_shape, coord)&lt;/code> to retrieve the tile of interest, in this case the &lt;code>m_block&lt;/code>-th block of Q. The &lt;code>coord&lt;/code> selects the &lt;code>(i, j)&lt;/code>-th tile when the tensor is cut into tiles of &lt;code>tile_shape&lt;/code>.&lt;/p>
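&lt;p>The offset math is easier to see without the CuTe wrapping. The sketch below computes the flat element offset that &lt;code>gQ(i, j)&lt;/code> should resolve to. It assumes Q is stored contiguously as (batch, seqlen, heads, head_dim), which is an assumption, since the real strides come from the caller. It also uses a hypothetical &lt;code>Params&lt;/code> struct holding only the three strides from the snippet.&lt;/p>

```cpp
// Host-side sketch of the address arithmetic behind mQ and local_tile.
// Params is a hypothetical stand-in holding only the three strides used
// in the snippet above; strides are in elements, head_dim has stride 1.
constexpr int kBlockM = 128;  // assumed block size

struct Params {
  long q_batch_stride;
  long q_head_stride;
  long q_row_stride;
};

// Element offset of mQ's base pointer for one (batch, head) pair.
constexpr long head_base(Params p, int batch_idx, int head_idx) {
  return batch_idx * p.q_batch_stride + head_idx * p.q_head_stride;
}

// Element offset of gQ(i, j). local_tile with coord (m_block, 0) and tile
// shape (kBlockM, kHeadDim) keeps rows m_block * kBlockM + i of mQ, so
// tile element (i, j) is mQ element (m_block * kBlockM + i, j).
constexpr long gq_offset(Params p, int batch_idx, int head_idx,
                         int m_block, int i, int j) {
  return head_base(p, batch_idx, head_idx) +
         (long(m_block) * kBlockM + i) * p.q_row_stride + j;
}

// Example: contiguous (batch, seqlen = 1024, heads = 8, head_dim = 64) layout.
constexpr Params kExample{1024L * 8 * 64, 64, 8 * 64};
static_assert(gq_offset(kExample, 0, 0, 1, 0, 0) == 128 * 512);
```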
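&lt;p>We could just as easily have pointed &lt;code>mQ&lt;/code> at the start of the batch/head dimensions and used &lt;code>local_tile&lt;/code> to index batch and head as well as &lt;code>m_block&lt;/code>. Both versions boil down to the same address arithmetic, so the generated code should match; it&amp;rsquo;s simply a matter of personal preference. The K and V GMEM tensors, by contrast, iterate over every block along the seqlen dimension, so their coord uses an underscore &lt;code>_&lt;/code> in that position. Instead of fixing one tile, &lt;code>_&lt;/code> keeps the whole mode: &lt;code>gK&lt;/code> gains a third mode that counts the N-blocks, which the main loop indexes to pick the current K tile.&lt;/p>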
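&lt;p>Here is a sketch of what that &lt;code>_&lt;/code> implies, again as plain host-side C++17 with an assumed &lt;code>kBlockN = 64&lt;/code>. It assumes &lt;code>local_tile&lt;/code> rounds the number of tiles up when the sequence length is not a multiple of the tile size. As far as we know, this matches how CuTe&amp;rsquo;s divide operations treat ragged extents. The helper names are hypothetical.&lt;/p>

```cpp
// Host-side sketch of what the _ in make_coord(_, 0) implies for gK/gV.
// Plain C++17, not CuTe API; kBlockN and sequence lengths are assumptions.
constexpr int kBlockN = 64;

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Extent of gK's third mode: how many K/V tiles one thread block walks over.
constexpr int num_kv_blocks(int seqlen_k) { return ceil_div(seqlen_k, kBlockN); }

// Row of mK where tile n starts, i.e. where gK(0, 0, n) lives.
constexpr int kv_block_first_row(int n) { return n * kBlockN; }

// When seqlen_k is not a multiple of kBlockN, the last tile hangs off the end
// of the tensor; only this many of its rows hold real data.
constexpr int valid_rows_in_last_block(int seqlen_k) {
  return seqlen_k - kv_block_first_row(num_kv_blocks(seqlen_k) - 1);
}

static_assert(num_kv_blocks(1000) == 16);
static_assert(valid_rows_in_last_block(1000) == 40);
```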
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">gK&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">local_tile&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">mK&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_coord&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="smem-registers">SMEM-&amp;gt;Registers&lt;/h2>
&lt;blockquote>
&lt;p>&lt;strong>Read this first: thread view vs. tile view.&lt;/strong> From here on out, we&amp;rsquo;re working with tensor-core fragments, and CuTe will start lying to you in a productive way. When CuTe shows you a fragment with shape &lt;code>((2,2), MMA_M, MMA_N)&lt;/code>, that is &lt;strong>not the shape of a tile&lt;/strong> — it&amp;rsquo;s the shape of &lt;em>one thread&amp;rsquo;s slice of every tile&lt;/em>. The &lt;code>(2,2)&lt;/code> is the 4 elements that thread holds in a single 16x8 atom; the &lt;code>MMA_M, MMA_N&lt;/code> count how many atoms tile across the full block. Every operation on a fragment tensor — every &lt;code>for r in size&amp;lt;0&amp;gt;(frag)&lt;/code>, every &lt;code>frag(i, j)&lt;/code> — is implicitly executed by all 32 threads of a warp in lock-step, each on its own values. CuTe abstractions (&lt;code>partition_fragment&lt;/code>, &lt;code>cute::gemm&lt;/code>, &lt;code>cute::copy&lt;/code>) make this look like normal tensor code, which is the exact source of confusion: the only place the thread view is &lt;em>visible&lt;/em> in code is the ~5 lines around &lt;code>get_thread_slice(tid)&lt;/code>. Whenever a layout stops making sense, ask &amp;ldquo;is this a tile shape or a thread-slice shape?&amp;rdquo; — it&amp;rsquo;s almost always the second.&lt;/p>
&lt;/blockquote>
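&lt;p>Here is the arithmetic behind that thread-slice shape, as a compile-time C++17 sketch. It assumes the SM80 16x8x16 atom with &lt;code>kNWarps = 4&lt;/code> warps stacked along M (the tiling defined in the Tiled MMA section). The block sizes are illustrative, and the assumption that MMA_N counts 8-wide atoms is ours.&lt;/p>

```cpp
// Sketch of the thread-slice arithmetic behind a fragment shaped
// ((2,2), MMA_M, MMA_N). Assumptions: the SM80 16x8x16 atom, kNWarps = 4
// warps stacked along M, and an S block of block_m x block_n.
constexpr int kNWarps = 4;
constexpr int kLanes = 32;
constexpr int kAtomM = 16, kAtomN = 8;

// Values one thread holds in a single 16x8 accumulator atom: the (2,2) mode.
constexpr int kValsPerAtom = kAtomM * kAtomN / kLanes;  // 4

// Atom repeats per thread along M and N: the MMA_M and MMA_N modes.
// Along M the warps split the rows; along N each warp covers every column.
constexpr int mma_m(int block_m) { return block_m / (kAtomM * kNWarps); }
constexpr int mma_n(int block_n) { return block_n / kAtomN; }

// Total accumulator values per thread (the size of the whole fragment).
constexpr int acc_vals_per_thread(int block_m, int block_n) {
  return kValsPerAtom * mma_m(block_m) * mma_n(block_n);
}

// Sanity check: 128 x 64 floats split evenly over 128 threads is 64 each.
static_assert(acc_vals_per_thread(128, 64) == 128 * 64 / (kNWarps * kLanes));
```

&lt;p>The final check is a useful habit: fragment size times thread count should give back the tile size. If it does not, you are probably reading a thread-slice shape as a tile shape.&lt;/p>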
&lt;p>We&amp;rsquo;ll issue a second &lt;code>Tiled_Copy&lt;/code> to copy from SMEM to the registers. The copy pattern is mostly the same, but instead of simply transferring memory from SMEM to the registers, we must format the SMEM and registers for the tensor core matrix multiply-add (MMA) instructions.&lt;/p>
&lt;p>Our first MMA GEMM is $S = QK^T$. Since Q and K are both stored row-major with the head dimension contiguous, they already match the operand layout the MMA expects, so the copy needs no transpose. We&amp;rsquo;ll get into the tensor core instructions shortly; for now, all we need to know is that Ampere natively supports 16x8x16 (MxNxK) MMAs. Each tensor op multiplies a $16\times 16$ A tile by a $16\times 8$ B tile into a $16\times 8$ accumulator:&lt;/p>
&lt;p>$$C = A\times B + C$$&lt;/p>
&lt;p>Each MMA instruction (&lt;code>mma.sync&lt;/code>) is issued by a single warp. Its 32 threads execute it together, with A, B, and C stored as &lt;strong>fragments&lt;/strong> spread across the registers of all 32 threads. Different warps issue their own MMAs on their own tiles, and accumulation happens in each warp&amp;rsquo;s C registers across successive MMAs. NVIDIA defines the register mapping for each architecture, and CuTe encodes it in the &lt;code>MMA_Atom&lt;/code> (more on this later). For now, all we need to know is that the &lt;code>Tiled_Copy&lt;/code> must place each thread&amp;rsquo;s share of A and B (Q and K) in its registers, while C (the accumulator) lives in registers from the start.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemCopyAtom&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">SM75_U32x4_LDSM_N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Our copy atom this time leverages the &lt;code>LDSM_N&lt;/code> SASS&lt;sup id="fnref2:3">&lt;a href="#fn:3" class="footnote-ref" role="doc-noteref">3&lt;/a>&lt;/sup> instruction: Load from Shared Memory with the &amp;ldquo;N&amp;rdquo;ormal (row-major, no-transpose) layout. Each thread receives 4x32-bit words = 128 bits per instruction, similar to our async load from before. However, this instruction is quite special&amp;ndash;it is &lt;em>specifically made for tensor core MMAs&lt;/em>. As we&amp;rsquo;ll see in the next section, the tensor cores require specific threads to hold specific pieces of each fragment. Although each thread supplies the SMEM address of one 128-bit row, it &lt;em>does not necessarily end up with that data&lt;/em>. Instead, &lt;code>LDSM&lt;/code> performs a specialized hardware warp shuffle so that each thread ends up with exactly the pieces its fragment needs.&lt;/p>
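&lt;p>The shuffle is easiest to see as two separate mappings: which row each lane points at, and which elements each lane receives. The sketch below encodes both for the &lt;code>.x4&lt;/code>, &lt;code>.b16&lt;/code>, non-transposed case, as we read them from the PTX ISA description of &lt;code>ldmatrix&lt;/code>. It is host-side C++17, and the names are our own.&lt;/p>

```cpp
// Sketch of ldmatrix.x4 data movement (b16, no .trans), following the
// fragment description in the PTX ISA. Plain C++17, no GPU needed.
// Each of the four 8x8 matrices has eight 16-byte rows in SMEM.
struct RowAddr { int matrix; int row; };
struct Elem { int matrix; int row; int col; };

// Address side: lane L passes the SMEM address of row L % 8 of matrix L / 8
// (lanes 0-7 -> matrix 0, lanes 8-15 -> matrix 1, and so on).
constexpr RowAddr row_supplied_by(int lane) { return {lane / 8, lane % 8}; }

// Data side: destination register reg is filled from matrix reg and holds
// the two adjacent halves at columns (lane % 4) * 2 + {0, 1} of row lane / 4.
constexpr Elem received(int lane, int reg, int half) {
  return {reg, lane / 4, (lane % 4) * 2 + half};
}

// Lane 13 points at row 5 of matrix 1 ...
static_assert(row_supplied_by(13).matrix == 1);
static_assert(row_supplied_by(13).row == 5);
// ... but receives row 3, columns 2-3 of that matrix.
static_assert(received(13, 1, 0).row == 3);
static_assert(received(13, 1, 1).col == 3);
```

&lt;p>The address a thread supplies and the data it receives are decoupled. That decoupling is what lets a single instruction land every element in its MMA-ready position.&lt;/p>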
&lt;p>This instruction is also commonly referred to by its PTX counterpart, &lt;code>ldmatrix&lt;/code>:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-sass" data-lang="sass">&lt;span class="line">&lt;span class="cl">&lt;span class="nt">ldmatrix&lt;/span>&lt;span class="nc">.sync.aligned.shape.num&lt;/span>&lt;span class="err">{&lt;/span>&lt;span class="nc">.trans&lt;/span>&lt;span class="err">}{&lt;/span>&lt;span class="nc">.ss&lt;/span>&lt;span class="err">}&lt;/span>&lt;span class="nc">.type&lt;/span> &lt;span class="nt">r&lt;/span>&lt;span class="o">,&lt;/span> &lt;span class="o">[&lt;/span>&lt;span class="nt">p&lt;/span>&lt;span class="o">];&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="na">.shape&lt;/span>&lt;span class="o"> =&lt;/span> &lt;span class="err">{&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">m8n8&lt;/span>&lt;span class="err">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="na">.num&lt;/span>&lt;span class="o"> =&lt;/span> &lt;span class="err">{&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">x1&lt;/span>&lt;span class="o">,&lt;/span> &lt;span class="o">.&lt;/span>&lt;span class="n">x2&lt;/span>&lt;span class="o">,&lt;/span> &lt;span class="o">.&lt;/span>&lt;span class="n">x4&lt;/span>&lt;span class="err">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="na">.ss&lt;/span>&lt;span class="o"> =&lt;/span> &lt;span class="err">{&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shared&lt;/span>&lt;span class="err">{&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cta&lt;/span>&lt;span class="err">}};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="na">.type&lt;/span>&lt;span class="o"> =&lt;/span> &lt;span class="err">{&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">b16&lt;/span>&lt;span class="err">};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Our specific copy atom maps to the &lt;code>ldmatrix...x4&lt;/code> variant, which loads an entire $4\times(8\times 8)=16\times 16$ fragment in one go. It drains through the &lt;a class="link" href="#how-do-32-threads-load-128-bits-each" >same four 128-byte transactions&lt;/a> as our GMEM async copy. However, unlike the GMEM copy, the &lt;code>LDSM&lt;/code> tiled copy has to be aware of the downstream MMA thread layout, which differs between fragments A, B, and C.&lt;/p>
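&lt;p>Here is why one &lt;code>ldmatrix.x4&lt;/code> fills exactly one A fragment. The sketch assumes the four 8x8 sub-matrices are addressed in the order top-left, bottom-left, top-right, bottom-right of the 16x16 tile (CuTe chooses the actual addressing for us). It compares the result against the A-fragment layout of &lt;code>mma.m16n8k16&lt;/code> from the PTX ISA, and also spells out the byte count behind the four 128-byte transactions.&lt;/p>

```cpp
// Host-side C++17 check that one ldmatrix.x4 yields one mma.m16n8k16 A fragment.
// Assumed order of the four 8x8 sub-matrices in the 16x16 tile:
// reg 0 = top-left, reg 1 = bottom-left, reg 2 = top-right, reg 3 = bottom-right.
struct Elem { int row; int col; };

// Tile element that lane receives in register reg (0-3), half 0 or 1.
constexpr Elem ldsm_x4(int lane, int reg, int half) {
  return {lane / 4 + (reg % 2) * 8,                // regs 1, 3: bottom half
          (lane % 4) * 2 + half + (reg / 2) * 8};  // regs 2, 3: right half
}

// A-fragment value a_i (i = 0-7) of mma.m16n8k16, per the PTX ISA tables.
constexpr Elem mma_a(int lane, int i) {
  return {lane / 4 + (i % 4 >= 2 ? 8 : 0),
          (lane % 4) * 2 + i % 2 + (i >= 4 ? 8 : 0)};
}

// a_i should sit in register i / 2, half i % 2, for every lane.
constexpr bool ldsm_matches_mma_a() {
  for (int lane = 0; lane != 32; ++lane)
    for (int i = 0; i != 8; ++i) {
      Elem x = ldsm_x4(lane, i / 2, i % 2);
      Elem y = mma_a(lane, i);
      if (x.row != y.row || x.col != y.col) return false;
    }
  return true;
}

// 32 lanes x one 16-byte row each = 512 bytes = four 128-byte transactions.
constexpr int kLdsmX4Bytes = 32 * 16;

static_assert(ldsm_matches_mma_a());
```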
&lt;blockquote>
&lt;p>We&amp;rsquo;ll cover more &lt;code>LDSM&lt;/code> details later when we use &lt;code>LDSM_T&lt;/code> for the &lt;a class="link" href="#ldsm-copy-atom" >V-copy&lt;/a>.&lt;/p>
&lt;/blockquote>
&lt;h3 id="tiled-mma">Tiled MMA&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/kernel_traits.cuh" target="_blank" rel="noopener"
>&lt;code>kernel_traits.cuh&lt;/code>&lt;/a>&lt;/p>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/03_mma.cu" target="_blank" rel="noopener"
>&lt;code>scratch/03_mma.cu&lt;/code>&lt;/a> (single MMA, no SMEM), &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/04_mma.cu" target="_blank" rel="noopener"
>&lt;code>scratch/04_mma.cu&lt;/code>&lt;/a> (full pipeline, verified)&lt;/p>
&lt;/blockquote>
&lt;p>Getting déjà vu yet? This time, we define the tiling for the MMA GEMM, starting from the following MMA atom:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// TN means transposed-normal for AxB. It&amp;#39;s a historical convention
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// that you can search up.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Practically, it means both A and B are row-major across M, N
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// i.e. K-dim is contiguous
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// (M, K), (N, K)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">using&lt;/span> &lt;span class="n">TiledMmaAtom&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MMA_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">SM80_16x8x16_F32F16F16F32_TN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>You might wonder why 16x8x16 and not 16x16x16. Again, it&amp;rsquo;s a hardware design choice made by NVIDIA engineers, and a few plausible reasons stand out:&lt;/p>
&lt;ol>
&lt;li>Less register pressure: B and C fragments are both 16x8, so each instruction needs fewer registers per thread than a 16x16 B and C would.&lt;/li>
&lt;li>More register reuse: covering a 16x16 output takes two 16x8 MMAs that read the same A fragment, so A is loaded once and used twice.&lt;/li>
&lt;li>Best silicon-area efficiency: NVIDIA presumably tested many combinations and found this size to be the best trade-off.&lt;/li>
&lt;/ol>
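&lt;p>To put rough numbers on the first two reasons, the sketch below counts 32-bit registers per lane for one instruction&amp;rsquo;s fp16 A and B operands and fp32 C operand. The 16x16x16 shape is hypothetical: the PTX &lt;code>mma&lt;/code> instruction offers no m16n16k16 shape for fp16, so it is only here for comparison.&lt;/p>

```cpp
// Per-lane operand registers for one MMA instruction (host-side C++17).
// The 16x16x16 row is a hypothetical comparison point, not a real mma shape.
constexpr int kLanes = 32;

// 32-bit registers per lane for a rows x cols operand spread over one warp.
constexpr int regs(int rows, int cols, int bytes_per_elem) {
  return rows * cols * bytes_per_elem / 4 / kLanes;
}

// fp16 A (M x K), fp16 B (K x N), fp32 C (M x N) for one instruction.
constexpr int operand_regs(int m, int n, int k) {
  return regs(m, k, 2) + regs(k, n, 2) + regs(m, n, 4);
}

constexpr int kRegs16x8x16 = operand_regs(16, 8, 16);    // 4 + 2 + 4 = 10
constexpr int kRegs16x16x16 = operand_regs(16, 16, 16);  // 4 + 4 + 8 = 16

// Covering a 16x16 output with the real atom takes two instructions that
// read the same A registers: A is fetched once and used twice.
constexpr int kMmasPer16x16Output = 16 / 8;
```

&lt;p>Under this count, each 16x8x16 instruction touches 10 operand registers per lane instead of 16, and both instructions covering a 16x16 output read the same 4 A registers.&lt;/p>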
&lt;p>This is by no means an exhaustive list, and tensor core shapes change from generation to generation for a multitude of reasons. It&amp;rsquo;s best to just use it as-is instead of wondering all day why it is the way it is. The MMA atom conveniently defines which threads get which chunks and which registers are used for the MMA, as shown below:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_atom.png"
width="1338"
height="1712"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_atom_hu2245a945016555384b8151cb2072505d_690927_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_atom_hu2245a945016555384b8151cb2072505d_690927_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_atom.png 1338w"
loading="lazy"
alt="MMA Atom thread layout. We can see each thread gets 32 bits (2 halves) at a time. For each 16x16 tile, each thread has two half-pairs per row, and only 1 half-pair per row for each 16x8 tile."
class="gallery-image"
data-flex-grow="78"
data-flex-basis="187px"
>
&lt;/p>
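&lt;p>The accumulator part of that picture can be written down directly from the PTX ISA fragment tables for &lt;code>mma.m16n8k16&lt;/code>. The host-side sketch below encodes the C layout and checks that the 32 lanes cover the 16x8 tile exactly once. It also names the four lanes that share a row, which is what makes the in-warp row reductions possible.&lt;/p>

```cpp
// Sketch of the 16x8 fp32 accumulator (C/D) layout of mma.m16n8k16, from the
// PTX ISA fragment tables: groupID = lane / 4, threadID_in_group = lane % 4.
struct Elem { int row; int col; };

// Accumulator value c_i (i = 0..3) of a lane: one 32-bit register each.
constexpr Elem mma_c(int lane, int i) {
  return {lane / 4 + (i >= 2 ? 8 : 0),  // c0, c1 in row g; c2, c3 in row g + 8
          (lane % 4) * 2 + i % 2};      // two adjacent columns
}

// Every element of the 16x8 tile is owned by exactly one (lane, i) pair.
constexpr bool c_covered_once() {
  int seen[16 * 8] = {};
  for (int lane = 0; lane != 32; ++lane)
    for (int i = 0; i != 4; ++i) {
      Elem e = mma_c(lane, i);
      seen[e.row * 8 + e.col] += 1;
    }
  for (int k = 0; k != 16 * 8; ++k)
    if (seen[k] != 1) return false;
  return true;
}

// Row r of the tile is split across lanes 4 * (r % 8) .. 4 * (r % 8) + 3.
constexpr int first_lane_of_row(int r) { return 4 * (r % 8); }

static_assert(c_covered_once());
```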
&lt;p>With this info, let&amp;rsquo;s define the full &lt;code>Tiled_MMA&lt;/code>:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">TiledMma&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">TiledMMA&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">MMA_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">SM80_16x8x16_F32F16F16F32_TN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kNWarps&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tile&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">kNWarps&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_16&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/tiled_mma.png"
width="2046"
height="1194"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/tiled_mma_hu6d0f913c83b5c5c4c0c7a54bbb63ffab_119902_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/tiled_mma_hu6d0f913c83b5c5c4c0c7a54bbb63ffab_119902_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/tiled_mma.png 2046w"
loading="lazy"
alt="Tiled MMA layout. Each solid color is one of four warps. If we had more rows/cols, the color pattern would repeat. The tiled layout for fragment B is pretty much the same as for fragment C, only with a size difference. Each 16x16 B-tile is composed of two fragments that are 1 N-tile adjacent for shape (K, N). Note that B is transposed in this visualization."
class="gallery-image"
data-flex-grow="171"
data-flex-basis="411px"
>
&lt;/p>
&lt;p>We chose 128 threads (4 warps) because each Ampere SM has 4 tensor cores, one per SM sub-partition&amp;ndash;a sensible choice to keep all of them busy. For the layout, we stack the warps along the M dimension (the rows of Q) and step through the K dimension (the head dimension) 16 elements at a time. Each tile stacks &lt;code>kNWarps&lt;/code> atoms on top of each other along M; for a 16x8x16 MMA atom, our tile shape becomes $(M, N, K) = (16\cdot\text{kNWarps}, 16, 16)$.&lt;/p>
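&lt;p>Here is the resulting split for one block of $S = QK^T$, as compile-time C++17 arithmetic. It assumes &lt;code>kNWarps = 4&lt;/code> and illustrative block sizes &lt;code>kBlockM = 128&lt;/code>, &lt;code>kBlockN = kHeadDim = 64&lt;/code>. The helper names are ours.&lt;/p>

```cpp
// How Layout<Shape<Int<kNWarps>, _1, _1>> with Tile<Int<16 * kNWarps>, _16, _16>
// splits one S block across warps (host-side sketch, assumed block sizes).
constexpr int kNWarps = 4;
constexpr int kBlockM = 128, kBlockN = 64, kHeadDim = 64;
constexpr int kTileM = 16 * kNWarps;  // rows covered by one tiled-MMA step

// First row of the r-th 16-row strip owned by warp w: warps take consecutive
// 16-row strips, and the pattern repeats every kTileM rows.
constexpr int warp_row_start(int w, int r) { return r * kTileM + w * 16; }

// Strips per warp, 8-wide N atoms, and 16-deep K steps for one S block.
constexpr int kStripsPerWarp = kBlockM / kTileM;                  // 2
constexpr int kNAtoms = kBlockN / 8;                              // 8
constexpr int kKSteps = kHeadDim / 16;                            // 4
constexpr int kMmasPerWarp = kStripsPerWarp * kNAtoms * kKSteps;  // 64

// The warps partition the work: no MMA is duplicated or skipped.
static_assert(kMmasPerWarp * kNWarps == (kBlockM / 16) * kNAtoms * kKSteps);
```

&lt;p>Under these assumptions, warp 1 owns rows 16-31 and 80-95 and issues 64 MMAs per block of S. The four warps together issue the 256 MMAs the block needs, with no overlap.&lt;/p>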
&lt;p>We flagged this M-tiling design choice in the &lt;a class="link" href="#basic-structure" >Basic Structure section&lt;/a>. The &lt;code>Layout&amp;lt;Shape&amp;lt;Int&amp;lt;kNWarps&amp;gt;, _1, _1&amp;gt;&amp;gt;&lt;/code> puts all the warps along M with &lt;code>_1&lt;/code> along N and K, which means &lt;strong>every warp owns whole rows of the output, never a horizontal slice of one&lt;/strong>. When we compute the per-row max and per-row sum during softmax, the values to be reduced live in registers within a single warp, so the reduction is a &lt;code>__shfl_xor_sync()&lt;/code> away: no SMEM staging, no thread-block sync. If we had tiled warps along N instead, that same reduction would have to cross warps, and we&amp;rsquo;d be staging through SMEM with &lt;code>__syncthreads()&lt;/code> on every iteration, crushing our performance. This cross-warp staging was a known bottleneck in the original FlashAttention-1, which split K and V across warps and had to combine their partial results through SMEM.&lt;/p>
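&lt;p>The reduction itself is small. In the accumulator layout, the four lanes &lt;code>4g&lt;/code> to &lt;code>4g+3&lt;/code> share rows &lt;code>g&lt;/code> and &lt;code>g+8&lt;/code>. Once each thread has folded together the columns it holds itself, two butterfly steps with XOR offsets 2 and 1 leave the row max in all four lanes. The sketch below simulates those steps on the host with an array of 32 lanes; on the GPU each step would be a &lt;code>__shfl_xor_sync&lt;/code> call, as noted in the comment. It is an illustration, not the kernel&amp;rsquo;s code.&lt;/p>

```cpp
// Host-side simulation of the quad all-reduce used for the per-row max.
// On the GPU each step would be
//   v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, offset));
// here the 32 lanes are just an array.
struct Warp { float v[32]; };

constexpr float fmax2(float a, float b) { return a > b ? a : b; }

// XOR with 2, then 1, only flips the low two lane bits, so values never leave
// their quad (lanes 4g .. 4g + 3); after both steps every lane holds the max.
constexpr Warp quad_allreduce_max(Warp w) {
  for (int offset = 2; offset != 0; offset /= 2) {
    Warp next = w;  // all lanes read before any lane writes, like a shuffle
    for (int lane = 0; lane != 32; ++lane)
      next.v[lane] = fmax2(w.v[lane], w.v[lane ^ offset]);
    w = next;
  }
  return w;
}

// Example input: lane l holds (5 * l) % 32, so every lane has a distinct value.
constexpr Warp example_input() {
  Warp w{};
  for (int lane = 0; lane != 32; ++lane) w.v[lane] = float((5 * lane) % 32);
  return w;
}

// Quad 0 holds 0, 5, 10, 15: all four lanes end up with 15.
static_assert(quad_allreduce_max(example_input()).v[2] == 15.0f);
```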
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: $N=16$, not 8. Because the atom&amp;rsquo;s N is only 8, the tile groups two adjacent N-atoms to produce one 16x16 output tile. That also means our &lt;code>LDSM&lt;/code> copy atom loads the B fragments for two adjacent N-atoms (two 16x8 slices of K or V) per instruction. This works because those N-tiles are index-adjacent.&lt;/p>
&lt;/blockquote>
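&lt;p>We can check the &amp;ldquo;two N-atoms per instruction&amp;rdquo; claim against the PTX fragment layout. The sketch assumes a 16 (N) x 16 (K) slice of row-major K in SMEM, loaded by a non-transposed &lt;code>ldmatrix.x4&lt;/code>. Its four 8x8 sub-matrices are addressed in the order listed in the comment, which is an assumption, since CuTe picks the real addressing. The sketch then compares each lane&amp;rsquo;s registers with the B fragments of two adjacent 16x8x16 MMAs.&lt;/p>

```cpp
// Host-side C++17 sketch: one non-transposed ldmatrix.x4 on a 16 (N) x 16 (K)
// slice of K, stored row-major (K contiguous), yields the B fragments of two
// adjacent 16x8x16 MMAs. Assumed sub-matrix order:
// reg 0 = (n 0-7, k 0-7), reg 1 = (n 0-7, k 8-15),
// reg 2 = (n 8-15, k 0-7), reg 3 = (n 8-15, k 8-15).
struct NK { int n; int k; };

// Element lane receives in register reg, half 0 or 1 (SMEM row = n, col = k).
constexpr NK ldsm_k_tile(int lane, int reg, int half) {
  return {lane / 4 + (reg / 2) * 8, (lane % 4) * 2 + half + (reg % 2) * 8};
}

// B fragment value b_i (i = 0..3) of N-atom `atom` (0 or 1), per the PTX ISA
// (row = k, col = n in the K x N operand), written here as (n, k).
constexpr NK mma_b(int lane, int atom, int i) {
  return {atom * 8 + lane / 4, (lane % 4) * 2 + i % 2 + (i >= 2 ? 8 : 0)};
}

// Registers 0-1 should feed atom 0 and registers 2-3 atom 1, for every lane.
constexpr bool x4_feeds_two_atoms() {
  for (int lane = 0; lane != 32; ++lane)
    for (int atom = 0; atom != 2; ++atom)
      for (int i = 0; i != 4; ++i) {
        NK x = ldsm_k_tile(lane, atom * 2 + i / 2, i % 2);
        NK y = mma_b(lane, atom, i);
        if (x.n != y.n || x.k != y.k) return false;
      }
  return true;
}

static_assert(x4_feeds_two_atoms());
```

&lt;p>Each SMEM row of K is one &lt;code>n&lt;/code> with &lt;code>k&lt;/code> contiguous, so the non-transposed load already produces B in the layout the MMA wants.&lt;/p>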
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_in_to_out.png"
width="2000"
height="1202"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_in_to_out_hud13a14ae1d8af1c8036e837b47af987c_92685_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_in_to_out_hud13a14ae1d8af1c8036e837b47af987c_92685_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_in_to_out.png 2000w"
loading="lazy"
alt="(16,16) x (16,8) MMA produces a (16,8) output. MMA of one A tile with two adjacent B tiles &amp;#43; concatenation produces one (16,16) output tile."
class="gallery-image"
data-flex-grow="166"
data-flex-basis="399px"
>
&lt;/p>
&lt;h3 id="tiled-copy-a-b-and-c">Tiled Copy A, B, and C&lt;/h3>
&lt;h4 id="what-is-a-fragment">What is a Fragment?&lt;/h4>
&lt;p>SMEM-&amp;gt;register copies operate on fragments. As mentioned earlier, a fragment is simply each thread&amp;rsquo;s share of the A, B, or C matrix used in the tensor core. The layout in the previous section shows which piece each thread gets, and this will become clearer once we work with it in detail. Since we tile Q, K, and V with these MMA atoms, each thread holds multiple fragments (see &lt;a class="link" href="#mma-shape" >MMA shape&lt;/a> later), one per atom it takes to tile our SMEM. CuTe therefore provides &lt;code>partition_fragment_A/B/C()&lt;/code> functions that allocate each thread&amp;rsquo;s register fragment, shaped after its partition of the tile. Which function to call depends on whether the tensor plays the A, B, or C role in the MMA, since each role has a different thread layout.&lt;/p>
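&lt;p>The fragment shapes follow from the atom and the warp layout. The sketch below assumes the &lt;code>TiledMma&lt;/code> from earlier with &lt;code>kNWarps = 4&lt;/code> and illustrative block sizes. It also assumes the rest modes count 16-row A atoms and 8-column B atoms. It is arithmetic only, not a call into CuTe.&lt;/p>

```cpp
// Shape arithmetic for the register fragments partition_fragment_A/B allocate
// (host-side C++17 sketch; block sizes and mode counting are assumptions).
constexpr int kNWarps = 4, kLanes = 32;
constexpr int kBlockM = 128, kBlockN = 64, kHeadDim = 64;

// tSrQ ~ ((2,2,2), MMA_M, MMA_K): 8 halves of a 16x16 A atom per thread,
// repeated over this warp's 16-row strips and over the 16-deep K steps.
constexpr int kAVals = 16 * 16 / kLanes;          // 8
constexpr int kAMmaM = kBlockM / (16 * kNWarps);  // 2
constexpr int kAMmaK = kHeadDim / 16;             // 4

// tSrK ~ ((2,2), MMA_N, MMA_K): 4 halves of a 16x8 B atom per thread.
// Warps only split M, so every warp covers all kBlockN rows of K.
constexpr int kBVals = 16 * 8 / kLanes;  // 4
constexpr int kBMmaN = kBlockN / 8;      // 8
constexpr int kBMmaK = kHeadDim / 16;    // 4

constexpr int kQHalvesPerThread = kAVals * kAMmaM * kAMmaK;  // 64
constexpr int kKHalvesPerThread = kBVals * kBMmaN * kBMmaK;  // 128

// Q is shared out across all 128 threads; K is replicated in every warp.
static_assert(kQHalvesPerThread * kLanes * kNWarps == kBlockM * kHeadDim);
static_assert(kKHalvesPerThread * kLanes == kBlockN * kHeadDim);
```

&lt;p>Note the asymmetry. Q is split across warps along M, so the 128 threads share it evenly. Every warp needs all of K, so the K fragment as declared covers the whole K tile within each warp. How many of those registers stay live at once is up to the copy/MMA loop and the compiler.&lt;/p>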
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: We explore a huge caveat with &lt;code>partition_fragment()&lt;/code> when we discuss &lt;a class="link" href="#svtnoswizzle-the-no-op-nobody-caught" >the fragment shape for V&lt;/a>.&lt;/p>
&lt;/blockquote>
&lt;p>As covered in &lt;a class="link" href="#registers-arent-memory" >Registers Aren&amp;rsquo;t Memory&lt;/a>, the register fragment looks like a tensor but it&amp;rsquo;s a 1-1 mapping into the register file, not addressable memory. Keep that in mind in this section.&lt;/p>
&lt;h4 id="q-k-smem-register-tiled-copy">Q, K SMEM-&amp;gt;Register Tiled Copy&lt;/h4>
&lt;p>The code pattern is mostly the same as our GMEM-&amp;gt;SMEM copy, with some SMEM-&amp;gt;register specifics. Mainly, the tiled copy interacts with our tiled MMA. So first, we have to declare our tiled MMA and the destination fragment registers:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// create tiled MMA
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">tiled_mma&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">TiledMma&lt;/span>&lt;span class="p">{};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// partition the fragments
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">thr_mma&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tSrQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">thr_mma&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_fragment_A&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tSrK&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">thr_mma&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_fragment_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sK&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// C does not need a slice of memory since
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// it is write-only. We can skip all the thread slicing
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// and partitioning and just get the fragments in
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// one go
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">acc_s&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">partition_fragment_C&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// initialize with 0.0
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">clear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Next, we create the tiled copy and partition SMEM for the copy transaction.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// create Q, K tiled copy
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_tiled_copy_Q&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tiled_copy_A&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SmemCopyAtom&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_tiled_copy_K&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tiled_copy_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SmemCopyAtom&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// thread slice of MMA
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_thr_copy_Q&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_tiled_copy_Q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_thr_copy_K&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_tiled_copy_K&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// partition SMEM
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// tSsQ = thread Score-smem Q
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">tSsQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_Q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">tSsK&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_K&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sK&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Notice we don&amp;rsquo;t partition the destination registers here &amp;ndash; only the SMEM source. Instead, we call &lt;code>retile_D&lt;/code> on the register fragments we already allocated with &lt;code>partition_fragment_A/B&lt;/code>, which gives the copy a view of those same registers in the layout it expects. The full rule is explained in the next section.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tXrQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_Q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">retile_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tSrQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tXrK&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_K&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">retile_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tSrK&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="partition-vs-retile">Partition vs. Retile&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/retile_viz.cu" target="_blank" rel="noopener"
>&lt;code>scratch/retile_viz.cu&lt;/code>&lt;/a> (visualize how &lt;code>retile_D&lt;/code>/&lt;code>retile_S&lt;/code> rebind register layouts)&lt;/p>
&lt;/blockquote>
&lt;p>Every tiled copy in this kernel boils down to one decision: do I &lt;code>partition&lt;/code> the source/destination, or &lt;code>retile&lt;/code> it? The answer depends entirely on whether the tensor is in shared/global memory or in registers, and whether it&amp;rsquo;s the source or destination of the copy.&lt;/p>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th style="text-align:left">Source/Dest&lt;/th>
&lt;th style="text-align:left">Mem Type&lt;/th>
&lt;th style="text-align:left">Function&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td style="text-align:left">Source&lt;/td>
&lt;td style="text-align:left">GMEM/SMEM&lt;/td>
&lt;td style="text-align:left">&lt;code>partition_S()&lt;/code>&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">Dest&lt;/td>
&lt;td style="text-align:left">GMEM/SMEM&lt;/td>
&lt;td style="text-align:left">&lt;code>partition_D()&lt;/code>&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">Source&lt;/td>
&lt;td style="text-align:left">Registers&lt;/td>
&lt;td style="text-align:left">&lt;code>retile_S()&lt;/code>&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">Dest&lt;/td>
&lt;td style="text-align:left">Registers&lt;/td>
&lt;td style="text-align:left">&lt;code>retile_D()&lt;/code>&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
&lt;p>&lt;strong>GMEM and SMEM are shared across threads&lt;/strong>, so they need to be sliced. The partitioner hands each thread its piece of the source/destination region. Registers are the opposite: each thread already owns its own set, so there&amp;rsquo;s no shared pool (&lt;a class="link" href="#registers-arent-memory" >Registers Aren&amp;rsquo;t Memory&lt;/a>). You don&amp;rsquo;t &lt;em>partition&lt;/em> a register tensor since there&amp;rsquo;s nothing to slice. You &lt;em>retile&lt;/em> it, because the register fragment was originally laid out for the MMA atom, and the copy atom may want a different layout in the same set of physical registers. &lt;code>retile_D/S&lt;/code> rebinds the layout without moving anything; it tells the copy atom which logical register goes where.&lt;/p>
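&lt;p>To make &amp;ldquo;rebinds the layout without moving anything&amp;rdquo; concrete, here is a toy host-side C++ analogy. It is not CuTe code, and the two index functions are hypothetical, chosen only so that they differ from the real MMA and copy layouts. One thread owns 8 registers, modeled as a plain array. The &amp;ldquo;MMA view&amp;rdquo; and the &amp;ldquo;copy view&amp;rdquo; are just two coordinate-to-slot mappings over that same array, so a value written through one view is read back through the other with no data movement.&lt;/p>

```cpp
// Toy analogy for retile (not CuTe): one thread's 8 registers and two index views.
// Both mappings are hypothetical, chosen only so that they differ.

// "MMA view": value v (0..3) of K-slice k (0..1) lives in slot k * 4 + v.
constexpr int mma_slot(int v, int k) { return k * 4 + v; }

// "Copy view": element e (0..1) of copy pair p (0..3) lives in slot 2 * p + e.
constexpr int copy_slot(int p, int e) { return 2 * p + e; }

// Fill the registers through the copy view, then read one back through the MMA view.
// The same array backs both views: nothing is moved or reallocated between them.
constexpr int read_mma_after_copy(int v, int k) {
  int regs[8] = {};
  for (int p = 0; p != 4; ++p)
    for (int e = 0; e != 2; ++e)
      regs[copy_slot(p, e)] = 10 * p + e;  // tag each value with its copy coordinate
  return regs[mma_slot(v, k)];
}

// MMA coordinate (v=1, k=1) is slot 5, which the copy view calls pair 2, element 1.
static_assert(read_mma_after_copy(1, 1) == 21, "both views alias the same registers");
```

&lt;p>In CuTe the two views are real layouts derived from the MMA and copy atoms, but the principle is the same: &lt;code>retile_D/S&lt;/code> changes how the copy indexes the registers, not where the values live.&lt;/p>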
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: You&amp;rsquo;ll see this rule apply everywhere: Q/K SMEM-&amp;gt;register (above), V SMEM-&amp;gt;register, the output register-&amp;gt;SMEM and SMEM-&amp;gt;register staging.&lt;/p>
&lt;/blockquote>
&lt;h2 id="register-copy-and-mma">Register Copy and MMA&lt;/h2>
&lt;p>Unlike the GMEM-&amp;gt;SMEM transaction, where we copy the whole tile in one go, we can pseudo-pipeline the fragment loads inside the GEMM loop over dimension K: each iteration runs the MMA on the current K-slice while loading the next K-slice&amp;rsquo;s fragments. This interleaves the &lt;code>ldmatrix&lt;/code> loads with compute and might save a bit of time, since the memory path and the tensor cores are separate functional units that can work concurrently. But mainly, by explicitly telling the compiler when each fragment needs to be ready, we reduce register pressure: only two K-slices of each fragment (the one being consumed and the one being prefetched) need to be live at any time.&lt;/p>
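&lt;p>The prefetch pattern is ordinary two-buffer software pipelining. As a minimal host-side sketch (plain C++, no CuTe; the slice count, slice length, and operand values are hypothetical), the function below computes a dot product split into K-slices. It &amp;ldquo;loads&amp;rdquo; slice 0 up front, then on each iteration loads slice &lt;code>i + 1&lt;/code> into the other buffer before consuming slice &lt;code>i&lt;/code>, so at most two slices are ever held at once.&lt;/p>

```cpp
// Two-buffer (double-buffered) K loop, host-side illustration only; not CuTe code.
constexpr int kSlices = 4;    // hypothetical number of K-slices
constexpr int kSliceLen = 8;  // hypothetical elements per K-slice

// Stand-ins for the operands living in "SMEM": element k of each vector.
constexpr int a_at(int k) { return k + 1; }
constexpr int b_at(int k) { return 2 * k - 3; }

constexpr int pipelined_dot() {
  int buf_a[2][kSliceLen] = {};  // two register "slots": current and next slice
  int buf_b[2][kSliceLen] = {};
  for (int e = 0; e != kSliceLen; ++e) {  // prologue: load slice 0 into slot 0
    buf_a[0][e] = a_at(e);
    buf_b[0][e] = b_at(e);
  }
  int acc = 0;
  for (int i = 0; i != kSlices; ++i) {
    int cur = i % 2;
    int nxt = (i + 1) % 2;  // holds slice i - 1, which is already consumed
    if (i + 1 != kSlices) {  // prefetch slice i + 1 into the other slot
      for (int e = 0; e != kSliceLen; ++e) {
        buf_a[nxt][e] = a_at((i + 1) * kSliceLen + e);
        buf_b[nxt][e] = b_at((i + 1) * kSliceLen + e);
      }
    }
    for (int e = 0; e != kSliceLen; ++e)  // "MMA" on the current slice
      acc += buf_a[cur][e] * buf_b[cur][e];
  }
  return acc;
}

// Reference: the same dot product with no slicing or buffering.
constexpr int reference_dot() {
  int acc = 0;
  for (int k = 0; k != kSlices * kSliceLen; ++k) acc += a_at(k) * b_at(k);
  return acc;
}

static_assert(pipelined_dot() == reference_dot(), "pipelining must not change the result");
```

&lt;p>In the CuTe loop, the register fragment spans every K-slice, and (as we understand it) it is the compiler&amp;rsquo;s liveness analysis that lets it reuse registers once a slice is consumed; the explicit two-slot buffer here just makes that liveness visible.&lt;/p>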
&lt;h3 id="mma-shape">MMA Shape&lt;/h3>
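&lt;p>The tiled MMA register fragments (&lt;code>tSrQ&lt;/code>, &lt;code>tSrK&lt;/code>) have shape &lt;code>(MMA, MMA_X, MMA_Y)&lt;/code> for a row-major tiling of shape &lt;code>(X, Y)&lt;/code> (see visualization in the &lt;a class="link" href="#fragment-reshape" >fragment reshape section&lt;/a>). The SMEM-side copy partitions (&lt;code>tSsQ&lt;/code>, &lt;code>tSsK&lt;/code>) are indexed the same way along the last mode, which is why the loop below can slice both with the same &lt;code>i&lt;/code>.&lt;/p>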
&lt;ul>
&lt;li>&lt;code>MMA&lt;/code>: shape/number of elements per thread for a single atom. For our tiled MMA, it&amp;rsquo;s 8 elements per thread for Q and 4 elements per thread for K, V, and the accumulator. The output of our SM80 16x8x16 atom has &lt;code>MMA=(2,2)&lt;/code>, which means each thread holds 4 values in the shape (2, 2).&lt;/li>
&lt;li>&lt;code>MMA_X&lt;/code> is the number of tiles along X and&lt;/li>
&lt;li>&lt;code>MMA_Y&lt;/code> is the number of tiles along Y. For &lt;code>tSrQ&lt;/code>, &lt;code>X=kBlockM&lt;/code> and &lt;code>Y=kHeadDim&lt;/code>, so the last mode indexes K-slices of the head dimension. By writing the K loop ourselves, we ensure the GEMM accumulates every K-slice into each output tile, while each warp holds all of the values of its output row tile.&lt;/li>
&lt;/ul>
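&lt;p>The per-thread counts above are simple arithmetic: each operand tile of one atom is spread evenly over the 32 threads of a warp. The sketch below checks the 8/4/4 numbers for the 16x8x16 atom, then multiplies by the number of atom tiles each thread covers. The tile sizes and the &amp;ldquo;4 warps stacked along M&amp;rdquo; arrangement are hypothetical example values, not necessarily this kernel&amp;rsquo;s configuration. CuTe may group the modes differently depending on the tiled MMA&amp;rsquo;s permutation, but the per-thread totals must agree with the even-split cross-checks at the bottom.&lt;/p>

```cpp
// Per-thread fragment sizes for an SM80 16x8x16 MMA atom (one warp = 32 threads).
constexpr int kAtomM = 16, kAtomN = 8, kAtomK = 16, kWarpSize = 32;

// Elements of each operand that one thread holds for a single atom.
constexpr int kA = kAtomM * kAtomK / kWarpSize;  // Q (A operand)
constexpr int kB = kAtomN * kAtomK / kWarpSize;  // K, V (B operand)
constexpr int kC = kAtomM * kAtomN / kWarpSize;  // accumulator
static_assert(kA == 8, "Q: 8 values per thread per atom");
static_assert(kB == 4, "K, V: 4 values per thread per atom");
static_assert(kC == 4, "accumulator: 4 values per thread per atom");

// Hypothetical tile sizes and warp arrangement (4 warps stacked along M).
constexpr int kBlockM = 64, kBlockN = 64, kHeadDim = 64, kWarpsM = 4;

// Atom tiles each thread covers in the S = Q K^T GEMM.
constexpr int kMmaM = kBlockM / (kAtomM * kWarpsM);  // each warp owns a 16-row slab
constexpr int kMmaN = kBlockN / kAtomN;              // every warp spans all of N
constexpr int kMmaK = kHeadDim / kAtomK;             // K-slices of the head dimension

constexpr int kFragQ = kA * kMmaM * kMmaK;  // 8 * 1 * 4
constexpr int kFragK = kB * kMmaN * kMmaK;  // 4 * 8 * 4
constexpr int kFragS = kC * kMmaM * kMmaN;  // 4 * 1 * 8

// Cross-check against an even split of each tile over the threads that share it.
static_assert(kFragQ * kWarpSize * kWarpsM == kBlockM * kHeadDim, "Q tile split over all warps");
static_assert(kFragK * kWarpSize == kBlockN * kHeadDim, "each warp needs the whole K tile");
static_assert(kFragS * kWarpSize * kWarpsM == kBlockM * kBlockN, "S tile split over all warps");
```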
&lt;h3 id="mma-loop-qkt-gemm">MMA Loop: QK^T GEMM&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/utils.cuh" target="_blank" rel="noopener"
>&lt;code>utils.cuh&lt;/code>&lt;/a> (&lt;code>gemm&lt;/code>)&lt;/p>
&lt;/blockquote>
&lt;p>We index these K-slices via &lt;code>tSrQ(_, _, i)&lt;/code> and &lt;code>tSrK(_, _, i)&lt;/code> to grab the relevant K-fragment on each loop iteration. The TiledMMA handles the M and N dimensions.&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_macro.png"
width="2270"
height="1244"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_macro_hu04f74aa79072a560bd0ec98594996ed7_145696_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_macro_hu04f74aa79072a560bd0ec98594996ed7_145696_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_macro_hu04f74aa79072a560bd0ec98594996ed7_145696_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_macro.png 2270w"
loading="lazy"
alt="Macro view of the MMA. We iterate over the K-dimension, each tile multiplying across and summing to form one output tile. The colors just mean they pair, not that they are the same. In the tiled MMA, CuTe handles all the M, N work on our behalf. We just have to accumulate over the K-dim."
class="gallery-image"
data-flex-grow="182"
data-flex-basis="437px"
>
&lt;/p>
&lt;p>Here&amp;rsquo;s the full GEMM block:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// load initial Q, K fragments (0)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_Q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tSsQ&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}),&lt;/span> &lt;span class="n">tXrQ&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_K&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tSsK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}),&lt;/span> &lt;span class="n">tXrK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// compile-time static, registers only live per iteration
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tSrQ&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// prefetch next Q, K block
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">i&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tSrQ&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_Q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tSsQ&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tXrQ&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_K&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tSsK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tXrK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// MMA on frags
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">gemm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tSrQ&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tSrK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">acc_s&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="bank-conflicts-and-smem-layout">Bank Conflicts and SMEM Layout&lt;/h2>
&lt;p>Ok, we have to address the elephant in the room. I&amp;rsquo;ve gone this far without talking about the SMEM layout, which is critical if we don&amp;rsquo;t want to kill all of our performance with suboptimal SMEM access patterns. If we simply stored data in SMEM in the same format as GMEM, we would quickly run into serious memory-bound issues due to &lt;strong>bank conflicts.&lt;/strong> If you&amp;rsquo;ve made it this far, you hopefully know what these are already. However, if you don&amp;rsquo;t:&lt;/p>
&lt;blockquote>
&lt;p>Bank Conflict: when &lt;strong>multiple threads in the same warp&lt;/strong> simultaneously request memory within the same bank in shared memory but across distinct addresses, we say there is a bank conflict. &lt;a class="link" href="https://modal.com/gpu-glossary/perf/bank-conflict" target="_blank" rel="noopener"
>Source&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;p>To provide highly parallel bandwidth, NVIDIA spreads shared memory across 32 banks, and each bank can serve one 4-byte word per cycle. Within a warp, threads that request &lt;em>different&lt;/em> addresses in the same bank cannot be served together (threads reading the exact same word are fine, since the value is broadcast). If two or more threads do hit different addresses in the same bank at the same time, the memory controller has no choice but to serialize the transactions&amp;ndash;each thread takes its turn reading from memory. If 5 threads access 5 different addresses in bank 13 at the same time, the memory transaction will take &lt;em>5 times as long&lt;/em> as if they read from 5 different banks.&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: As we described in our &lt;a class="link" href="#how-do-32-threads-load-128-bits-each" >four-transaction loading pattern&lt;/a>, SMEM can fetch at most one 4-byte word from each of its 32 banks per transaction, or 128 bytes of data. With a 2-way conflict, our poor memory controller has to issue two 128-byte transactions instead of one. In the simple case where each thread reads one 4-byte word, each transaction serves only 16 threads, so only 64 of its 128 bytes are useful and the rest is thrown away.&lt;/p>
&lt;/blockquote>
&lt;p>It&amp;rsquo;s a hardware design choice influenced by power consumption, wiring, latency, and speed. If you somehow figured out how to access any piece of data in SMEM concurrently for free, then you should be instantly nominated for the Turing Award or sent straight to a psychiatric ward. Unfortunately, dealing with bank conflicts is just a part of GPU programming.&lt;/p>
&lt;p>Each of these 32 banks is 4 bytes wide: consecutive 4-byte words are stored in consecutive banks. For example, in an fp32 array &lt;code>float x[] = {0.f, 1.f, 2.f, 3.f}&lt;/code> starting in bank 0, &lt;code>x[0]&lt;/code> would be in bank 0, &lt;code>x[1]&lt;/code> in bank 1, etc. If the 32 threads in a warp simultaneously accessed 32 consecutive float32s, each would hit a different bank, which is conflict-free. This &amp;ldquo;ideal&amp;rdquo; use case is by design.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="kt">int&lt;/span> &lt;span class="n">bank&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">byte_address&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>However, much of the time, we aren&amp;rsquo;t just linearly traversing our data. Sometimes threads work across rows, columns, or both. Let&amp;rsquo;s go back to our fp32 example. Imagine we have a 32x32 row-major float matrix and we want to add 1 to each element. One reasonable approach is to have one warp traverse the columns in lock-step:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">j&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">j&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">j&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// each thread traverses one row
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// each warp is hence one column per cycle
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">smem&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">thread_idx&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="n">j&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="mf">1.0f&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>In this example, at &lt;code>j=0&lt;/code>, thread 0 accesses &lt;code>(0, 0)&lt;/code>, thread 1 accesses &lt;code>(1, 0)&lt;/code>, &amp;hellip;, and thread 31 accesses &lt;code>(31, 0)&lt;/code>. Since our SMEM array is 32x32, consecutive rows are 32 floats apart: 32 4-byte words, or exactly one lap around the 32 banks. This means all 32 threads hit bank 0 at &lt;code>j=0&lt;/code>, bank 1 at &lt;code>j=1&lt;/code>, and so on for every column. This is the ultimate 32-way bank conflict, which makes every access 32x slower. It doesn&amp;rsquo;t matter how optimized the rest of your kernel is&amp;ndash;this access pattern will absolutely destroy your performance.&lt;/p>
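&lt;p>We can check the 32-way claim with a small host-side model of the bank rule. The sketch below is plain C++ (the helper names are ours) and only models one warp-wide request of 4-byte words. It computes the conflict degree as the largest number of &lt;em>distinct&lt;/em> words requested from any single bank, since threads reading the same word are served by a broadcast. It then applies the model to the column walk above and, for contrast, to the &amp;ldquo;ideal&amp;rdquo; case of 32 lanes reading one row of consecutive words.&lt;/p>

```cpp
// Conflict degree of one warp-wide SMEM access of 4-byte words.
// word_idx[t] is the 4-byte word index requested by lane t (byte_address / 4).
constexpr int conflict_degree(const int* word_idx) {
  int worst = 0;
  for (int bank = 0; bank != 32; ++bank) {
    int distinct = 0;
    for (int t = 0; t != 32; ++t) {
      if (word_idx[t] % 32 != bank) continue;
      bool seen_before = false;  // same word already requested by an earlier lane?
      for (int u = 0; u != t; ++u)
        if (word_idx[u] == word_idx[t]) seen_before = true;
      if (!seen_before) ++distinct;  // broadcast: repeated words cost nothing extra
    }
    if (distinct > worst) worst = distinct;
  }
  return worst;
}

// Word index of element (row, col) in a row-major float array.
constexpr int word_of(int row, int col, int row_stride) { return row * row_stride + col; }

// Column walk: at iteration j, lane t touches smem[t][j].
constexpr int column_access_degree(int j, int row_stride) {
  int idx[32] = {};
  for (int t = 0; t != 32; ++t) idx[t] = word_of(t, j, row_stride);
  return conflict_degree(idx);
}

// Ideal case: at iteration i, lane t touches smem[i][t] (32 consecutive words).
constexpr int row_access_degree(int i, int row_stride) {
  int idx[32] = {};
  for (int t = 0; t != 32; ++t) idx[t] = word_of(i, t, row_stride);
  return conflict_degree(idx);
}

static_assert(column_access_degree(0, 32) == 32, "column walk on 32x32: 32-way conflict");
static_assert(row_access_degree(0, 32) == 1, "consecutive words: conflict-free");
```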
&lt;p>In this case, the fix is simple. We can have the warp iterate over one row per cycle, which is 32 contiguous elements = 32 consecutive banks&amp;ndash;no conflict, no problems.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">smem&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="n">thread_idx&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="mf">1.0f&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>If for some reason you cannot simply &amp;ldquo;traverse the rows&amp;rdquo;, there are two other common patterns.&lt;/p>
&lt;h3 id="padding">Padding&lt;/h3>
&lt;p>If you&amp;rsquo;ve ever worked with an image processing pipeline or CNNs, you&amp;rsquo;ve seen a similar idea: add some extra elements next to the real data. If you&amp;rsquo;ve ever worked with non-power-of-two shapes in deep learning, I&amp;rsquo;m sure you&amp;rsquo;ve padded your weights or inputs because powers of two are nicer to the kernels.&lt;/p>
&lt;p>Funnily enough, with SMEM padding we often try to &lt;em>break&lt;/em> these power-of-two symmetries to improve our bank access patterns.&lt;/p>
&lt;p>Going back to our example, the reason we end up with bank conflicts is that our row stride is a multiple of the 128-byte bank cycle (32 banks of 4 bytes each). Any two addresses separated by a multiple of 128 bytes map to the same bank, so a column-wise access pattern on a 32x32 float array is an absolute death sentence. It wouldn&amp;rsquo;t be any better for 32x64, 32x96, or 32x1024 float arrays either, because the row length in each case (256, 384, or 4096 bytes) is also a multiple of 128 bytes.&lt;/p>
&lt;p>We can break this 128-byte stride pattern simply by padding each row with an extra float. So instead of 32x32, we now give our SMEM the shape 32x33. The SMEM chunk occupies 128 more bytes (one dummy 4-byte float per row, 32 rows), but our column access pattern no longer suffers from bank conflicts. Looking at the column access pattern from before, at &lt;code>j=0&lt;/code>, thread 0 still accesses (0, 0), thread 1 still accesses (1, 0), &amp;hellip;, and thread 31 still accesses (31, 0). But consecutive rows are now 33 words apart, so thread 0 accesses word 0 while thread 1 now accesses word 33, not 32. So in one iteration, thread 0 accesses bank 0, thread 1 accesses bank 1, &amp;hellip;, and thread 31 accesses bank 31. On the next iteration, everything shifts by one bank: thread 0 accesses bank 1, thread 1 accesses bank 2, and thread 31 wraps around to bank 0. We are now conflict-free, at the expense of 32 &amp;ldquo;empty&amp;rdquo; floats.&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/padding.png"
width="2302"
height="1296"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/padding_hu3e502985a7792c47736837d91aa89839_287968_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/padding_hu3e502985a7792c47736837d91aa89839_287968_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/padding_hu3e502985a7792c47736837d91aa89839_287968_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/padding.png 2302w"
loading="lazy"
alt="Example of two-way column conflict access and a simple padding solution."
class="gallery-image"
data-flex-grow="177"
data-flex-basis="426px"
>
&lt;/p>
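&lt;p>To check the arithmetic, here is a minimal Python sketch of the bank math. It assumes the usual model of 32 four-byte banks, with one float per 4-byte word. The helper names are hypothetical and not taken from the repo. For each bank, it counts how many distinct words a warp-wide column read touches. Stride 64 models the 32x64 case from the text.&lt;/p>

```python
NUM_BANKS = 32  # 32 banks, each 4 bytes wide


def conflict_degree(word_indices):
    """Worst-case number of distinct 4-byte words that share one bank.

    1 means conflict-free. Two threads reading the *same* word is a
    broadcast, not a conflict, so duplicates are collapsed first.
    """
    words_per_bank = {}
    for w in set(word_indices):
        words_per_bank.setdefault(w % NUM_BANKS, set()).add(w)
    return max(len(ws) for ws in words_per_bank.values())


def column_access(row_stride, col, rows=32):
    """Word (float) indices touched when thread i reads element (i, col)."""
    return [i * row_stride + col for i in range(rows)]


for stride in (32, 33, 64):
    worst = max(conflict_degree(column_access(stride, j)) for j in range(32))
    print(f"row stride {stride} floats: worst conflict degree {worst}")
# row stride 32 floats: worst conflict degree 32
# row stride 33 floats: worst conflict degree 1
# row stride 64 floats: worst conflict degree 32

padding_bytes = 32 * 4  # one dummy float (4 bytes) in each of 32 rows
print(padding_bytes)  # 128
```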
&lt;p>When you aren&amp;rsquo;t constrained by SMEM limits, padding is often a very simple and worthwhile tradeoff. It&amp;rsquo;s easy to implement as long as you match your strides correctly and load/write from SMEM following your new padding rules. However, if you&amp;rsquo;re dealing with complex memory access patterns or different data types (a long is 2 banks wide, 2 halfs fit in one bank), padding might be too complicated or completely insufficient for your use case.&lt;/p>
&lt;h3 id="swizzling">Swizzling&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/swizzle_sim.py" target="_blank" rel="noopener"
>&lt;code>scratch/swizzle_sim.py&lt;/code>&lt;/a> (pure-Python &lt;code>Swizzle&amp;lt;B,M,S&amp;gt;&lt;/code> simulator, toy with the bit math)&lt;/p>
&lt;/blockquote>
&lt;p>This is precisely the problem in FA2. We have some copy-atom- and MMA-specific read/write access patterns and we&amp;rsquo;re working with 16-bit halfs, which make padding unattractive if not impossible. Swizzling comes to the rescue.&lt;/p>
&lt;p>Swizzling is your answer to the brilliant thought: &amp;ldquo;what if our access patterns magically happened to use different banks?&amp;rdquo; Using some bit magic, swizzling rearranges the mapping of data elements in shared memory to avoid bank conflicts.&lt;/p>
&lt;p>Back to our example. For our column access pattern on the 32x32 array, we &amp;ldquo;reinterpret&amp;rdquo; our SMEM so that address 0 is bank 0, address 32 is bank 1, &amp;hellip;, and address 31*32 is bank 31. It&amp;rsquo;s a scrambler (or swizzler, if you will) that maps your (i, j) to a true address under the hood such that your bank conflicts magically disappear. Before each write and read to SMEM, we swizzle the logical coordinate (i, j) and translate it to a physical address, so that even though we think we&amp;rsquo;re writing (1, 3) to memory location $32\cdot 1 + 3$, we&amp;rsquo;re actually writing it to some swizzled address under the hood. The writer and consumer are none the wiser. As long as whatever is written to (1, 3) is what comes back when reading (1, 3), neither side cares.&lt;/p>
&lt;blockquote>
&lt;p>Think of it like a valet attendant. You give your keys to the guy up front, and he parks your car somewhere in the garage. When you come back from your day of disappointing your family, you simply ask for your car back. They fetch it, you get in, and you leave. You don&amp;rsquo;t care whether it was on floor 1 or floor 9001&amp;ndash;you just care that you got your car back.&lt;/p>
&lt;/blockquote>
&lt;p>There are countless ways to scramble addresses, but a useful scrambler has to meet a few criteria:&lt;/p>
&lt;ol>
&lt;li>The mapping must be one-to-one. Each (i, j) has to have its own unique physical location in memory.&lt;/li>
&lt;li>It must be fast and deterministic.&lt;/li>
&lt;li>If you are reading or writing N bytes at once, those N bytes still have to be contiguous in memory. Your data might be fp16, but you might be reading 8 fp16s at once; those 16 bytes (128 bits) must still be contiguous in the swizzled domain. Even though you could technically split those 16 bytes into 4-byte bank chunks and distribute them throughout memory, the logic becomes way more convoluted and you likely lose vectorization or cache performance.&lt;/li>
&lt;/ol>
&lt;p>Swizzling accomplishes this with a bit of clever bit arithmetic. It uses the XOR operation, which satisfies the three conditions above in the following way:&lt;/p>
&lt;ol>
&lt;li>XOR with a fixed key is bijective. For a fixed &lt;code>b&lt;/code>, &lt;code>a xor b&lt;/code> gives a different output for every &lt;code>a&lt;/code>, and applying &lt;code>xor b&lt;/code> a second time undoes it.&lt;/li>
&lt;li>XOR bit instructions are as fast as you can get and are fully deterministic. XOR also never produces a result wider than its inputs: any two n-bit values XOR to another n-bit value, so swizzled addresses stay inside the tile.&lt;/li>
&lt;li>We can leave the low bits that index within a contiguous chunk untouched and XOR only the bits above them, which effectively treats each chunk as a single address. For example, if we&amp;rsquo;re loading 8 fp16s, bytes 0-15 behave as one unit (chunk 0), since we copy those bytes in one go.&lt;/li>
&lt;/ol>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th style="text-align:center">Input A&lt;/th>
&lt;th style="text-align:center">Input B&lt;/th>
&lt;th style="text-align:center">Output (A ⊕ B)&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td style="text-align:center">0&lt;/td>
&lt;td style="text-align:center">0&lt;/td>
&lt;td style="text-align:center">0&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:center">0&lt;/td>
&lt;td style="text-align:center">1&lt;/td>
&lt;td style="text-align:center">1&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:center">1&lt;/td>
&lt;td style="text-align:center">0&lt;/td>
&lt;td style="text-align:center">1&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:center">1&lt;/td>
&lt;td style="text-align:center">1&lt;/td>
&lt;td style="text-align:center">0&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
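&lt;p>The three XOR properties above are easy to brute-force on a small address space. This is a sketch with a hypothetical helper, not code from the repo. &lt;code>n_bits&lt;/code> is the address width and &lt;code>m&lt;/code> is the number of low bits that must stay contiguous.&lt;/p>

```python
def check_xor_properties(n_bits, m):
    """Brute-force the three swizzling properties of XOR on n_bits-wide addresses."""
    size = 2 ** n_bits
    chunk = 2 ** m

    for key in range(size):
        # 1. x -> x ^ key is a bijection, and applying it twice undoes it.
        if len({x ^ key for x in range(size)}) != size:
            return False
        if any((x ^ key) ^ key != x for x in range(size)):
            return False

    # 2. XOR of two n-bit values is still an n-bit value (stays in range).
    if max(a ^ b for a in range(size) for b in range(size)) != size - 1:
        return False

    # 3. XORing a key into the bits above the low m bits leaves those low bits
    #    alone, so each aligned chunk of 2**m elements stays contiguous.
    for x in range(size):
        for high_key in range(size // chunk):
            if (x ^ (high_key * chunk)) % chunk != x % chunk:
                return False
    return True


print(check_xor_properties(n_bits=6, m=3))  # True
```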
&lt;p>So how do we actually apply this XOR? It&amp;rsquo;s miraculously simple:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Swizzle&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">col&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row&lt;/span> &lt;span class="o">^&lt;/span> &lt;span class="n">col&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Why does this work? Let&amp;rsquo;s examine our float example. We access &lt;code>(0...31, 0)&lt;/code>, then &lt;code>(0...31, 1)&lt;/code>, and so on. For column 0, &lt;code>n ^ 0 = n&lt;/code>, so logical &lt;code>(i, 0)&lt;/code> lands at physical &lt;code>(i, i)&lt;/code>. Since a row is exactly 32 floats, the physical column is the bank, so thread &lt;code>i&lt;/code> hits bank &lt;code>i&lt;/code> and we cover all 32 banks. For the other columns, we&amp;rsquo;ve shown that &lt;code>row ^ col&lt;/code> takes a different value for every &lt;code>row&lt;/code> when &lt;code>col&lt;/code> is fixed, so we are again guaranteed to hit all 32 banks with all 32 threads. Neat! If you&amp;rsquo;re unconvinced, try a few column examples yourself.&lt;/p>
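&lt;p>Here is a small Python check of that argument. It assumes a 32x32 float tile where each row is exactly one 128-byte bank cycle, so the physical column index is the bank. The &lt;code>swizzle&lt;/code> function is just the formula above written out.&lt;/p>

```python
ROWS = COLS = 32  # 32x32 floats; one row == 32 banks x 4 bytes


def swizzle(row, col):
    """Logical (row, col) -> physical (row, row ^ col)."""
    return row, row ^ col


def physical_word(row, col):
    r, c = swizzle(row, col)
    return r * COLS + c


def banks_for_column(col):
    """Bank hit by thread i when it reads logical element (i, col)."""
    return [physical_word(i, col) % 32 for i in range(ROWS)]


# Column reads: all 32 threads land in 32 different banks.
assert all(len(set(banks_for_column(j))) == 32 for j in range(COLS))
# Row reads stay conflict-free, since XOR with a fixed row just permutes columns.
assert all(len({physical_word(i, j) % 32 for j in range(COLS)}) == 32 for i in range(ROWS))
# No two logical elements share a physical slot.
assert len({physical_word(i, j) for i in range(ROWS) for j in range(COLS)}) == ROWS * COLS

print(banks_for_column(1)[:4])  # [1, 0, 3, 2]
```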
&lt;p>Let&amp;rsquo;s visualize where each float ends up. The number of each square represents the column it originally belonged to. The color points to where it was originally.&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-32x32.png"
width="3114"
height="1577"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-32x32_hu6596b57894f6490cbf0e273c8793060d_380231_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-32x32_hu6596b57894f6490cbf0e273c8793060d_380231_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-32x32_hu6596b57894f6490cbf0e273c8793060d_380231_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-32x32.png 3114w"
loading="lazy"
alt="32x32 swizzle pattern"
class="gallery-image"
data-flex-grow="197"
data-flex-basis="473px"
>
&lt;/p>
&lt;p>Ok, this is kind of hard to look at. Let&amp;rsquo;s look at an 8x8 example for more clarity on where each column ends up:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-8x8.png"
width="1749"
height="980"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-8x8_hu6a7db30d722cbb431d4322770efc8b0b_91421_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-8x8_hu6a7db30d722cbb431d4322770efc8b0b_91421_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle-8x8.png 1749w"
loading="lazy"
alt="8x8 swizzle pattern"
class="gallery-image"
data-flex-grow="178"
data-flex-basis="428px"
>
&lt;/p>
&lt;p>We can now see that each element of each column ends up in a different bank. Any column access pattern now hits all 32 banks in their 128-byte glory. XOR interleaves our elements in this beautiful diagonal butterfly pattern, which is easiest to see in the 32x32 grid.&lt;/p>
&lt;blockquote>
&lt;p>This XOR technique works great, but it&amp;rsquo;s not obvious why it is the default option. Part of it seems like divine benevolence, which is probably true, but the short answer is that it&amp;rsquo;s fast, it works, and the scrambled order it produces is one that normal kernel access patterns almost never line up with. It isn&amp;rsquo;t foolproof and may need to be combined with padding or different access patterns; more complex multidimensional kernels typically employ even more complex swizzling patterns. This article shows in more detail why XOR works: &lt;a class="link" href="https://leimao.github.io/blog/CuTe-Swizzle/" target="_blank" rel="noopener"
>https://leimao.github.io/blog/CuTe-Swizzle/&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;h3 id="swizzling-fa2">Swizzling FA2&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/kernel_traits.cuh" target="_blank" rel="noopener"
>&lt;code>kernel_traits.cuh&lt;/code>&lt;/a>&lt;/p>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/bench_swizzle_writes.cu" target="_blank" rel="noopener"
>&lt;code>scratch/bench_swizzle_writes.cu&lt;/code>&lt;/a> (benchmark cp.async / STS.128 with vs. without swizzle), &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/swizzle_layouts.cu" target="_blank" rel="noopener"
>&lt;code>scratch/swizzle_layouts.cu&lt;/code>&lt;/a> (print FA2 SMEM layouts)&lt;/p>
&lt;/blockquote>
&lt;p>The fp32 example was quite trivial. Our FA2 pattern is slightly more complex, as we have to deal with tiled copy patterns, MMA atom layouts, and vectorized loads. As a result, we have to redefine what &amp;ldquo;row&amp;rdquo; and &amp;ldquo;column&amp;rdquo; mean via the Swizzle Atom in CuTe.&lt;/p>
&lt;p>We have two interactions with SMEM: GMEM-&amp;gt;SMEM write and SMEM-&amp;gt;register read.&lt;/p>
&lt;h4 id="gmem-smem-write-requirements">GMEM-&amp;gt;SMEM Write Requirements&lt;/h4>
&lt;p>Recall the &lt;a class="link" href="#how-do-32-threads-load-128-bits-each" >four 128-byte transactions&lt;/a> that drain a warp-wide cp.async: each transaction writes a contiguous 128 bytes (32 bank accesses) and must hit all 32 banks for optimal performance. Since that vectorized write is conflict-free by default, any swizzle must operate &lt;em>on top of&lt;/em> each thread&amp;rsquo;s 16-byte contiguous chunk (8 halfs): whole chunks can move, but the 8 halfs inside a chunk must stay together. Everything else is fair game.&lt;/p>
&lt;p>Since we&amp;rsquo;re free to choose which contiguous 128-byte chunk each quarter-warp writes, this transaction doesn&amp;rsquo;t need a swizzle to be conflict-free. We just have to make sure that if we do swizzle SMEM, we keep each 8-half block contiguous in memory.&lt;/p>
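&lt;p>Here is a sketch of that requirement for the FA2 tile. It makes three assumptions: 64-half rows, the &lt;code>Swizzle&amp;lt;3, 3, 3&amp;gt;&lt;/code> bit math this section builds toward (written out by hand), and a hypothetical quarter-warp in which thread &lt;code>t&lt;/code> writes 16-byte chunk &lt;code>t&lt;/code> of one row. It checks that each row&amp;rsquo;s 128-byte write still covers all 32 banks and that every 8-half chunk stays contiguous.&lt;/p>

```python
def swizzle_333(a):
    """Swizzle(3, 3, 3) on an fp16 element offset: XOR bits 6-8 (row % 8) into bits 3-5 (chunk)."""
    return a ^ (((a >> 6) % 8) * 8)


def banks_of_chunk(elem_offset):
    """The four 4-byte banks covered by a 16-byte chunk starting at an fp16 offset."""
    first_word = (elem_offset * 2) // 4
    return [(first_word + k) % 32 for k in range(4)]


def quarter_warp_write_banks(row):
    """Hypothetical quarter-warp: 8 threads, thread t writes logical chunk t of `row`."""
    banks = []
    for t in range(8):
        banks += banks_of_chunk(swizzle_333(row * 64 + t * 8))
    return banks


# Each row's 128-byte write still touches every bank exactly once.
for row in range(16):
    assert sorted(quarter_warp_write_banks(row)) == list(range(32))

# The 8 halfs inside each chunk stay contiguous and in order after swizzling.
for a in range(16 * 64):
    assert swizzle_333(a) == swizzle_333(a - a % 8) + a % 8

print(swizzle_333(64 + 5))  # 77: row 1, chunk 0 moves to chunk 1, offset 5 kept
```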
&lt;h4 id="smem-registers-read-requirements">SMEM-&amp;gt;Registers Read Requirements&lt;/h4>
&lt;p>Our SMEM-&amp;gt;register transaction occurs during our SMEM tiled copy. Each thread is still loading 32-bits x 4 = 128 bits, like in our GMEM atom. For the GMEM load, all we had to do was load the entire tile, so we could choose 128-byte contiguous chunks to avoid bank conflicts. We can&amp;rsquo;t do this for SMEM, since the read pattern depends on the shape of the MMA fragments.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemCopyAtom&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">SM75_U32x4_LDSM_N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">TiledMmaAtom&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MMA_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">SM80_16x8x16_F32F16F16F32_TN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>If we don&amp;rsquo;t swizzle the SMEM layout, we&amp;rsquo;d simply have the layout &lt;code>(kBlockM, kHeadDim)&lt;/code>. Each &lt;code>MMA_Atom&lt;/code> would tile using a 16x16 chunk out of our SMEM per A-fragment (or 16x8 for B- or C-fragments). Let&amp;rsquo;s examine the bank conflict:&lt;/p>
&lt;p>As before, banks cycle every 128 bytes, which is 32 consecutive floats or 64 halfs. If we have &lt;code>kHeadDim=64&lt;/code>, each SMEM row is exactly one 128-byte bank cycle, so any threads that touch the same column in one load cycle conflict. Our &lt;code>SM75_U32x4_LDSM_N&lt;/code> copy atom gives each thread four 32-bit registers, i.e. 32 threads x 4 x 4 bytes = 512 bytes per warp. A 16x16 fp16 fragment is also 512 bytes, so each copy-atom call loads exactly one 16x16 A-fragment or two 16x8 B/C-fragments. Ideally, we want this load to drain through the same optimal &lt;a class="link" href="#how-do-32-threads-load-128-bits-each" >four 128-byte transactions&lt;/a> as our GMEM load, but that only happens if the reads are conflict-free.&lt;/p>
&lt;p>Let&amp;rsquo;s analyze our conflict pattern. The fragment spans 16 rows, and without a swizzle every row&amp;rsquo;s slice of a given column sits in the same four banks, so naively we&amp;rsquo;re looking at a 16-way conflict. Since the hardware only services a quarter-warp (128 bytes) at a time, each transaction touches at most 8 rows, which makes it an 8-way bank conflict. Either way, that&amp;rsquo;s up to an 8x slowdown on a single SMEM load.&lt;/p>
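&lt;p>Here is a sketch of the unswizzled case. It assumes &lt;code>kHeadDim=64&lt;/code> (64-half rows) and that each quarter-warp phase of the &lt;code>ldmatrix&lt;/code> load reads one 16-byte slice from each of 8 consecutive rows (the 8-row case above). The helper names are hypothetical.&lt;/p>

```python
# ldmatrix .x4 byte math: 32 threads x 4 registers x 4 bytes == one 16x16 fp16 tile.
assert 32 * 4 * 4 == 16 * 16 * 2 == 512


def conflict_degree_16B(elem_offsets):
    """Worst-case bank conflict for one phase of 16-byte (8 x fp16) reads.

    Counts distinct 4-byte words per bank; 1 means conflict-free.
    """
    words_per_bank = {}
    for off in elem_offsets:
        first_word = (off * 2) // 4
        for k in range(4):
            w = first_word + k
            words_per_bank.setdefault(w % 32, set()).add(w)
    return max(len(ws) for ws in words_per_bank.values())


def phase_offsets(row_stride, row0, chunk):
    """fp16 offsets of the 8 row slices read in one (hypothetical) 8-row phase."""
    return [(row0 + i) * row_stride + chunk * 8 for i in range(8)]


print(conflict_degree_16B(phase_offsets(64, 0, 0)))  # 8
print(conflict_degree_16B(phase_offsets(64, 8, 1)))  # 8
```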
&lt;blockquote>
&lt;p>&lt;strong>Aside&lt;/strong>: I&amp;rsquo;m not 100% sure what the quarter-warp slice actually looks like. It could be consecutive groups of 8 threads or some other pattern. In our SMEM copy, our quarter-warp transaction could technically be an 8-row x 1-column load OR a 4-row x 2-column load, which would be an 8-way or 4-way conflict depending on what the default actually is. You can certainly find out by running &lt;code>nsight&lt;/code> on a non-swizzled SMEM copy. Regardless, we don&amp;rsquo;t want a 4x or an 8x slowdown. For what it&amp;rsquo;s worth, the PTX docs for &lt;code>ldmatrix&lt;/code> have threads 0-7 supply the row addresses of the first 8x8 matrix, threads 8-15 the second, and so on. That points to the 8-row x 1-column case, which is the one our swizzle fix is built to resolve.&lt;/p>
&lt;/blockquote>
&lt;p>Okay, let&amp;rsquo;s fix this problem then. We could use padding, but it costs SMEM we&amp;rsquo;d rather not spend. For the 8-row slice, we&amp;rsquo;d need each row to start 4 banks (16 bytes, or 8 halfs) after the previous one, so we&amp;rsquo;d pad every row by 8 halfs. This increases our memory footprint by &lt;code>8*kBlockM&lt;/code> halfs per tile, a 12.5% increase over &lt;code>kHeadDim=64&lt;/code>, paid for Q, K, and V alike. We&amp;rsquo;re running pretty tight on SMEM, and swizzling gets us the same result without wasting a single byte.&lt;/p>
&lt;p>So our best option is to swizzle. We need to keep each 8-half chunk intact, which means that for some fp16 address A, the bottom 3 bits must pass through the swizzle untouched. What are our row and column? The row is simply the row of SMEM. In our example, each row is 64 halfs, so for fp16 address A the row is all the bits from bit 6 upward, i.e. &lt;code>A &amp;gt;&amp;gt; 6&lt;/code>. The column is the bits in between our contiguous chunk and our row. With 64 columns and 8 halfs per chunk, we have 8 8-half columns, which become the 3 bits (bits 3-5) sitting between the row bits and the bottom 3 chunk bits.&lt;/p>
&lt;p>CuTe defines this parameterization with the Swizzle struct:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Swizzle&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">B&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">M&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">S&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">swizzle&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;ul>
&lt;li>B: column bits; after we&amp;rsquo;ve set aside the M chunk bits, how many bits index the columns? For us, it&amp;rsquo;s 3. These are the bits that get XORed.&lt;/li>
&lt;li>M: the number of LSB bits you want to keep contiguous (untouched). We want 8 contiguous halfs, so 3 LSB bits.&lt;/li>
&lt;li>S: shift; how far &amp;ldquo;left&amp;rdquo; of the column bits the row bits start. For our case, the row bits start at bit 6, so &lt;code>S=6-M=6-3=3&lt;/code>. Put together, the swizzle XORs bits &lt;code>[M+S, M+S+B)&lt;/code> into bits &lt;code>[M, M+B)&lt;/code>.&lt;/li>
&lt;/ul>
&lt;blockquote>
&lt;p>For our 32x32 float example, let&amp;rsquo;s compute B, M, S. We only look at one float at a time, so &lt;code>M=0&lt;/code>. We have 32 columns/floats per row, so &lt;code>B=log2(32)=5&lt;/code>. Finally, our row bits are just all the bits above the columns, so &lt;code>S=B=5&lt;/code>. Since we only have 32 rows, we&amp;rsquo;ll only ever have 5 row bits as well, but Swizzle doesn&amp;rsquo;t need to know that&amp;ndash;our swizzle pattern just computes the translation, and we&amp;rsquo;re responsible for providing it the relevant SMEM pointers.&lt;/p>
&lt;/blockquote>
&lt;p>Notice that B and S are separate parameters. The row bits and column bits can&amp;rsquo;t overlap (CuTe requires &lt;code>|S| &amp;gt;= B&lt;/code>), but they don&amp;rsquo;t need to be adjacent either: when &lt;code>S &amp;gt; B&lt;/code>, the bits between them are left out of the XOR. In our case, our row and column bits &lt;em>are&lt;/em> adjacent, so &lt;code>B=S&lt;/code>. For different strides or certain layouts, this split gives us flexibility to ensure our swizzles point to the correct bits.&lt;/p>
&lt;p>So our swizzle atom is simply &lt;code>Swizzle&amp;lt;3, 3, 3&amp;gt;{}&lt;/code>.&lt;/p>
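&lt;p>To make the parameters concrete, here is a small Python model of &lt;code>Swizzle&amp;lt;B, M, S&amp;gt;&lt;/code> for positive S. It is a sketch of the bit math described above, not CuTe&amp;rsquo;s implementation: bits &lt;code>[M+S, M+S+B)&lt;/code> are XORed into bits &lt;code>[M, M+B)&lt;/code>. It runs two checks. First, &lt;code>Swizzle&amp;lt;5, 0, 5&amp;gt;&lt;/code> reproduces the &lt;code>(row, row ^ col)&lt;/code> mapping for the 32x32 float tile. Second, &lt;code>Swizzle&amp;lt;3, 3, 3&amp;gt;&lt;/code> makes the 8-row &lt;code>ldmatrix&lt;/code> slice of a 64-half-row tile conflict-free. The phase model is an assumption.&lt;/p>

```python
class Swizzle:
    """Model of Swizzle(B, M, S) for S > 0: XOR bits [M+S, M+S+B) into bits [M, M+B)."""

    def __init__(self, b, m, s):
        # Keep the source ("row") and target ("column") bit ranges disjoint.
        assert s >= b, "S must be at least B"
        self.b, self.m, self.s = b, m, s

    def __call__(self, offset):
        row_bits = (offset >> (self.m + self.s)) % 2 ** self.b
        return offset ^ (row_bits * 2 ** self.m)


# 32x32 floats: Swizzle(5, 0, 5) is exactly (row, row ^ col).
sw505 = Swizzle(5, 0, 5)
assert all(sw505(r * 32 + c) == r * 32 + (r ^ c) for r in range(32) for c in range(32))

# FA2 fp16 tile with 64-half rows: Swizzle(3, 3, 3).
sw333 = Swizzle(3, 3, 3)


def phase_banks(swz, row0, chunk, row_len=64):
    """Banks hit when rows row0..row0+7 each read logical 8-half chunk `chunk` (16 bytes)."""
    banks = []
    for i in range(8):
        phys = swz((row0 + i) * row_len + chunk * 8)  # fp16 element offset
        first_word = (phys * 2) // 4
        banks += [(first_word + k) % 32 for k in range(4)]
    return banks


for row0 in (0, 8, 16):
    for chunk in range(8):
        assert sorted(phase_banks(sw333, row0, chunk)) == list(range(32))

print(sw333(1 * 64 + 0))  # 72: row 1's chunk 0 is stored in physical chunk 1
```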
&lt;h4 id="kblocksmem">kBlockSmem&lt;/h4>
&lt;p>You&amp;rsquo;ll see this swizzle pattern a lot for fp16, since the bank-conflict repeat cycle occurs at 64-half intervals, so it often makes sense to structure your SMEM such that each row covers all 32 banks. For FA2, most kernels opt for &lt;code>hdim=32, 64, 128&lt;/code>. For &lt;code>hdim=128&lt;/code> we&amp;rsquo;d have to redo all of the swizzling math for this new column size, so instead we can set a &lt;code>kBlockSmem&lt;/code> capped at 64, which lets us use one swizzle atom for everything. This means less templating for kernel-size definitions, nothing more. If you wanted to recompute the shapes and swizzling for larger hdims, you&amp;rsquo;re perfectly welcome to.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">kBlockKSmem&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">kHeadDim&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;blockquote>
&lt;p>For &lt;code>hdim=32&lt;/code>, you still have to redeclare some things, for example &lt;code>B=2&lt;/code> for the swizzle atom. I bring this caveat up because it&amp;rsquo;s the path the FA2 source code took. It&amp;rsquo;s not the only implementation and not necessarily the best one&amp;ndash;it just might be a point of confusion when reading their &lt;code>kernel_traits.h&lt;/code> definition. We&amp;rsquo;ll cover another huge caveat in our &lt;a class="link" href="#svtnoswizzle-the-no-op-nobody-caught" >V-fragment section&lt;/a>.&lt;/p>
&lt;/blockquote>
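&lt;p>Here is a sketch of the &lt;code>hdim=32&lt;/code> case. It assumes the &lt;code>kBlockKSmem&lt;/code> rule shown above and a swizzle of &lt;code>Swizzle&amp;lt;2, 3, 3&amp;gt;&lt;/code>: B=2 as the note says, with M and S assumed to stay at 3. The helper names are hypothetical. With 32-half (64-byte) rows, two rows share one 128-byte bank cycle, so an unswizzled 8-row read is a 4-way conflict. Keeping S=3 means the XOR uses bits 6-7 (&lt;code>row &amp;gt;&amp;gt; 1&lt;/code>). Combined with bit 5 (row % 2), that gives each of the 8 rows its own 16-byte bank group.&lt;/p>

```python
def kblock_k_smem(head_dim):
    """The kBlockKSmem rule quoted above: 64 if head_dim % 64 == 0 else 32."""
    return 64 if head_dim % 64 == 0 else 32


def swizzle_b(kblock):
    """B = log2(number of 8-half chunks per SMEM row): 3 for 64 halfs, 2 for 32."""
    return (kblock // 8).bit_length() - 1


def swizzle(offset, b, m=3, s=3):
    """XOR bits [m+s, m+s+b) into bits [m, m+b)."""
    return offset ^ (((offset >> (m + s)) % 2 ** b) * 2 ** m)


def phase_banks(row_len, b, row0, chunk):
    """Banks hit when rows row0..row0+7 each read logical 8-half chunk `chunk`."""
    banks = []
    for i in range(8):
        phys = swizzle((row0 + i) * row_len + chunk * 8, b)
        first_word = (phys * 2) // 4
        banks += [(first_word + k) % 32 for k in range(4)]
    return banks


for hdim in (32, 64, 128):
    kb = kblock_k_smem(hdim)
    print(hdim, kb, swizzle_b(kb))
# 32 32 2
# 64 64 3
# 128 64 3

# hdim=32, unswizzled (b=0): 32 bank-words land on only 8 banks, a 4-way conflict ...
print(len(set(phase_banks(32, 0, 0, 0))))  # 8
# ... while Swizzle(2, 3, 3) spreads them over all 32 banks.
for row0 in (0, 8):
    for chunk in range(4):
        assert sorted(phase_banks(32, 2, row0, chunk)) == list(range(32))
```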
&lt;h4 id="swizzle-composition">Swizzle Composition&lt;/h4>
&lt;p>Now let&amp;rsquo;s actually make the SMEM layout. Since we have a swizzle and the actual SMEM dimensions, our resulting &lt;code>SmemLayout&lt;/code> is a tiled layout&amp;ndash;we have to tile the swizzle on top of the underlying memory. We first create our tile atom and then tile the atom to our SMEM shape.&lt;/p>
&lt;p>The swizzle atom is a composition of a swizzle and the layout underneath. The layout provides the raw coordinates/address to the swizzler, so that B, M, S actually mean something. Our swizzle is &lt;code>Swizzle&amp;lt;3, 3, 3&amp;gt;&lt;/code>, and our layout underneath is the SMEM subsection we&amp;rsquo;re actually scrambling. From our analysis earlier, it has 8 rows and spans a full &lt;code>kBlockKSmem&lt;/code>-wide row, so that each &lt;code>LDSM_N&lt;/code> load feeding our 16x8x16 MMA becomes bank-conflict-free. Therefore, the layout has shape &lt;code>(8, kBlockKSmem)&lt;/code> and stride &lt;code>(kBlockKSmem, 1)&lt;/code>.&lt;/p>
&lt;p>We use the &lt;code>composition(f1, f2)&lt;/code> function, which composes the layouts as &lt;code>f1(f2(x))&lt;/code>. The raw coordinates are translated into the unswizzled address, which gets fed to the swizzler&amp;ndash;therefore &lt;code>f1=Swizzle&lt;/code> and &lt;code>f2=Layout&lt;/code>. We then tile this atom over our overall SMEM shape. For Q, this is &lt;code>(kBlockM, kHeadDim)&lt;/code>, so for &lt;code>hdim=128&lt;/code> the 64-column atom is simply repeated along the columns. Since the result is a type in our kernel traits, we wrap it in &lt;code>decltype&lt;/code> just like the atom.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemLayoutAtomQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">composition&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Swizzle&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{})&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemLayoutQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tile_to_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SmemLayoutAtomQ&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>We can finally replace the layout we used to make &lt;code>sQ&lt;/code> above. &lt;code>sK&lt;/code> and &lt;code>sV&lt;/code> are an exercise left to the reader.&lt;/p>
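&lt;p>Tiling the 8-row atom down the rows is safe for a simple reason. For 64-half rows, &lt;code>Swizzle&amp;lt;3, 3, 3&amp;gt;&lt;/code> only reads bits 6-8 (row % 8) and only writes bits 3-5, so stacking copies of the atom gives the same offsets as swizzling the flat row-major address. The sketch below checks this with a simplified model of the composed atom and of row-wise tiling. It is my own model, not CuTe&amp;rsquo;s &lt;code>tile_to_shape&lt;/code>, and &lt;code>kBlockM=128&lt;/code> is just an example value.&lt;/p>

```python
def swizzle_333(a):
    """Swizzle(3, 3, 3) on an fp16 offset: XOR bits 6-8 into bits 3-5."""
    return a ^ (((a >> 6) % 8) * 8)


def atom_offset(r, c, k_smem=64):
    """Swizzle applied to the unswizzled (8, k_smem):(k_smem, 1) atom offset, i.e. f1(f2(x))."""
    return swizzle_333(r * k_smem + c)


def tiled_offset(r, c, k_smem=64):
    """Simplified row-wise tiling: atom copies stacked down the rows, 8 * k_smem halfs apart."""
    tile, r_in = divmod(r, 8)
    return tile * 8 * k_smem + atom_offset(r_in, c, k_smem)


kBlockM = 128  # example tile height
assert all(
    tiled_offset(r, c) == swizzle_333(r * 64 + c)
    for r in range(kBlockM)
    for c in range(64)
)
# Every logical element still gets a unique physical slot inside the tile.
assert sorted(tiled_offset(r, c) for r in range(kBlockM) for c in range(64)) == list(
    range(kBlockM * 64)
)
print(tiled_offset(9, 0))  # 584: second atom (512) + row 1's chunk 0 moved to chunk 1 (72)
```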
&lt;h2 id="dealing-with-v-copies">Dealing with V Copies&lt;/h2>
&lt;p>V is a slightly different beast, since it doesn&amp;rsquo;t follow the row-major loading pattern of Q and K during &lt;code>O=S@V&lt;/code>. When we compute our attention scores S, the resulting shape is &lt;code>(kBlockM, kBlockN)&lt;/code>. V has shape &lt;code>(kBlockN, kHeadDim)&lt;/code>, so the contraction dim &lt;code>kBlockN&lt;/code> is V&amp;rsquo;s &lt;em>first&lt;/em> dimension. Our copy/MMA pattern (the &lt;code>TN&lt;/code> MMA) expects the contraction dim to be the second shape dimension of both operands, so we have to transpose V. As a result, we have to make transpose-view tensors for V&amp;rsquo;s SMEM layouts to make sure the copies and fragments are correct.&lt;/p>
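&lt;p>Here is a toy Python sketch of the transpose view, with made-up sizes. V stays row-major in one flat buffer, and &lt;code>Vt&lt;/code> reads that same buffer with swapped strides, so the view itself moves no data. Reading V through this view puts the contraction dim &lt;code>kBlockN&lt;/code> in the second position for both operands, which is the shape the &lt;code>TN&lt;/code> MMA expects. The result matches the plain &lt;code>S@V&lt;/code> product. On the GPU, the actual rearrangement into registers is done by the transposed &lt;code>ldmatrix&lt;/code> discussed below.&lt;/p>

```python
# Toy sizes (hypothetical): S is (kBlockM, kBlockN), V is (kBlockN, kHeadDim).
kBlockM, kBlockN, kHeadDim = 2, 3, 4
S = [[1, 2, 3], [4, 5, 6]]
V_flat = list(range(kBlockN * kHeadDim))  # row-major: contiguous along kHeadDim


def V(n, d):
    """Row-major view: shape (kBlockN, kHeadDim), strides (kHeadDim, 1)."""
    return V_flat[n * kHeadDim + d * 1]


def Vt(d, n):
    """Transposed view of the SAME buffer: shape (kHeadDim, kBlockN), strides (1, kHeadDim)."""
    return V_flat[d * 1 + n * kHeadDim]


# Plain O = S @ V: contracts over n, which is V's FIRST index.
O_plain = [[sum(S[m][n] * V(n, d) for n in range(kBlockN)) for d in range(kHeadDim)]
           for m in range(kBlockM)]

# "TN" form: both operands indexed (outer, k) with k = kBlockN as the SECOND index.
O_tn = [[sum(S[m][n] * Vt(d, n) for n in range(kBlockN)) for d in range(kHeadDim)]
        for m in range(kBlockM)]

assert O_plain == O_tn
print(O_tn)  # [[32, 38, 44, 50], [68, 83, 98, 113]]
```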
&lt;h3 id="v-gmem-smem">V: GMEM-&amp;gt;SMEM&lt;/h3>
&lt;p>To get maximum coalesced, vectorized load performance, we can simply copy V in its row-major form from GMEM to SMEM. We need to eventually transpose V before it hits the register fragments, and Turing (SM75) and later architectures, including Ampere, fortunately provide a transposed &lt;code>ldmatrix&lt;/code> variant that does so for us. As a result, we only have to worry about the transpose once we hit the SMEM-&amp;gt;register stage. The GMEM-&amp;gt;SMEM copy fully mirrors the tiled copy for K from earlier:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">mV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_gmem_ptr&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">reinterpret_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">const&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span> &lt;span class="o">*&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">v_ptr&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">batch_idx&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">v_batch_stride&lt;/span> &lt;span class="o">+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">head_idx&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">v_head_stride&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">seqlen_k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">head_dim&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">v_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">gV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">local_tile&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">mV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_coord&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sV&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sK&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sK&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Traits&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">SmemLayoutKV&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// (VCPY, VCPY_N, VCPY_K, nblocksN)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tVgV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_thr_copy_QKV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gV&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tVsV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_thr_copy_QKV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sV&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="v-smem-register">V: SMEM-&amp;gt;Register&lt;/h3>
&lt;p>This is the step where we have to tread a bit carefully. V is sitting in SMEM in the same format as Q and K&amp;ndash;contiguous along &lt;code>kHeadDim&lt;/code>&amp;ndash;so we can&amp;rsquo;t just copy our SMEM-&amp;gt;register pipeline from earlier. This part is a bit confusing, so let&amp;rsquo;s visualize the problem first:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_v_layout.png"
width="1920"
height="1080"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_v_layout_hu0ed971fc0e6f32098603b2d6f2c28536_106126_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_v_layout_hu0ed971fc0e6f32098603b2d6f2c28536_106126_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_v_layout.png 1920w"
loading="lazy"
alt="SMEM physical layout for our MMAs. Each block is a tile."
class="gallery-image"
data-flex-grow="177"
data-flex-basis="426px"
>
&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: The tiled MMA visualization in our &lt;a class="link" href="#tiled-mma" >tiled MMA section&lt;/a> was a human-friendly logical view, not what &lt;code>sV&lt;/code> actually looks like here. It doesn&amp;rsquo;t represent the physical SMEM layout we have.&lt;/p>
&lt;/blockquote>
&lt;p>As we can see, Q and K are both row-contiguous along &lt;code>kHeadDim&lt;/code>, which is exactly the dimension their matmul contracts over. S and V contract over &lt;code>kBlockN&lt;/code>, not &lt;code>kHeadDim&lt;/code>, so V is &lt;em>not&lt;/em> contiguous along its contraction dimension. Each row of V is contiguous along &lt;code>kHeadDim&lt;/code>, while the MMA needs to walk down &lt;code>kBlockN&lt;/code>. As a result, we have to tile it &amp;ldquo;vertically&amp;rdquo; along the columns for the tiled MMA.&lt;/p>
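&lt;p>To make the &amp;ldquo;contiguous along the wrong dimension&amp;rdquo; point concrete, here&amp;rsquo;s a tiny plain-C++ sketch (no CuTe, swizzle ignored). It computes element offsets in a row-major &lt;code>(kBlockN, kHeadDim)&lt;/code> tile and measures the stride along each matmul&amp;rsquo;s contraction index. The tile sizes are illustrative assumptions, and every function name is hypothetical. Everything is &lt;code>constexpr&lt;/code>, so it can be checked at compile time with &lt;code>static_assert&lt;/code>.&lt;/p>

```cpp
// Plain C++ sketch (no CuTe): a row-major K or V tile in SMEM, ignoring the swizzle.
// kBlockN and kHeadDim are illustrative values, not taken from a specific config.
constexpr int kBlockN = 64;
constexpr int kHeadDim = 64;

// K and V tiles share the same storage: (kBlockN, kHeadDim), contiguous along kHeadDim.
constexpr int kv_offset(int n, int h) { return n * kHeadDim + h; }

// S = Q K^T contracts over kHeadDim: one step along the contraction index moves 1 element.
constexpr int k_contraction_stride() { return kv_offset(0, 1) - kv_offset(0, 0); }

// O = S V contracts over kBlockN: one step along the contraction index jumps a whole row.
constexpr int v_contraction_stride() { return kv_offset(1, 0) - kv_offset(0, 0); }
```

&lt;p>A 128-bit load grabs 8 consecutive halves. For K, that covers 8 steps of the contraction dimension at once. For V, it covers 8 different &lt;code>kHeadDim&lt;/code> columns of a single &lt;code>kBlockN&lt;/code> step.&lt;/p>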
&lt;p>But what does &amp;ldquo;vertically&amp;rdquo; even mean? We were pretty hand-wavy about the &lt;code>SM75_U32x4_LDSM_N&lt;/code> atom earlier, so let&amp;rsquo;s clarify it now:&lt;/p>
&lt;h4 id="ldsm-copy-atom">LDSM Copy Atom&lt;/h4>
&lt;p>When we issue the &lt;code>LDSM_N&lt;/code> instruction (the &lt;code>x4&lt;/code> variant), we load an entire 16x16 fragment in one go. That fragment is four 8x8 matrices, and each matrix row is 8 contiguous halves. That is exactly one 128-bit chunk, which is also how our SMEM is laid out. So each 16x16 &lt;code>ldmatrix&lt;/code> load takes in exactly 32 SMEM row addresses, one from each thread in the warp. The &lt;code>U32x4&lt;/code> denomination means each thread receives four 32-bit registers. This is perfect, since the MMA atom underneath understands that each 32-bit register holds two fp16s.&lt;/p>
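&lt;p>Here&amp;rsquo;s a plain-C++ sketch of the non-transposed distribution, written from my reading of the &lt;code>ldmatrix&lt;/code> m8n8 &lt;code>.b16&lt;/code> fragment description in the PTX ISA. For &lt;code>.x4&lt;/code>, lanes 0-7 supply the row addresses of the first 8x8 matrix, lanes 8-15 those of the second, and so on. Every lane then receives row &lt;code>lane / 4&lt;/code>, columns &lt;code>2 * (lane % 4)&lt;/code> and &lt;code>2 * (lane % 4) + 1&lt;/code> of each matrix. The function names are hypothetical, and offsets are element indices inside one row-major 8x8 matrix.&lt;/p>

```cpp
// Plain C++ sketch of non-transposed ldmatrix (.x4, .b16) following the PTX fragment layout.
// Element offsets index one row-major 8x8 matrix (8 halves = 128 bits per row).

// Address side: lanes 0-7 give the row addresses of matrix 0, lanes 8-15 of matrix 1, etc.
constexpr int addr_matrix(int lane) { return lane / 8; }
constexpr int addr_row(int lane) { return lane % 8; }

// Data side: for every matrix, a lane gets row lane / 4 and columns
// 2 * (lane % 4) and 2 * (lane % 4) + 1, packed into one 32-bit register.
constexpr int ldsm_n_elem(int lane, int half) {
  return (lane / 4) * 8 + 2 * (lane % 4) + half;
}

// Each (matrix, row) pair is supplied by exactly one of the 32 lanes.
constexpr bool addresses_unique() {
  int seen[32] = {};
  for (int lane = 0; lane != 32; ++lane) seen[addr_matrix(lane) * 8 + addr_row(lane)] += 1;
  for (int i = 0; i != 32; ++i)
    if (seen[i] != 1) return false;
  return true;
}

// The two halves in each register are neighbors in SMEM.
constexpr bool register_pairs_contiguous() {
  for (int lane = 0; lane != 32; ++lane)
    if (ldsm_n_elem(lane, 1) - ldsm_n_elem(lane, 0) != 1) return false;
  return true;
}

// The 32 lanes together cover all 64 elements of each 8x8 matrix exactly once.
constexpr bool covers_matrix() {
  int seen[64] = {};
  for (int lane = 0; lane != 32; ++lane)
    for (int half = 0; half != 2; ++half) seen[ldsm_n_elem(lane, half)] += 1;
  for (int i = 0; i != 64; ++i)
    if (seen[i] != 1) return false;
  return true;
}
```

&lt;p>The bookkeeping works out as four matrices of 64 halves each: 256 = 16 x 16 elements, delivered as 32 lanes x 4 registers x 2 halves.&lt;/p>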
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: Ampere tensor cores support &lt;code>INT4&lt;/code>, &lt;code>INT8&lt;/code>, &lt;code>FP16&lt;/code>, &lt;code>BF16&lt;/code>, and &lt;code>TF32&lt;/code> (19 bits). The &lt;code>U32x4&lt;/code> modifier can be confusing because it only describes the source data/register size, not the actual data type underneath. The &lt;code>LDSM&lt;/code> instructions are purpose-made for tensor core MMAs and don&amp;rsquo;t support standard 32-bit floats.&lt;/p>
&lt;/blockquote>
&lt;p>The &lt;code>LDSM_T&lt;/code> instruction is slightly different because the contiguous data is transposed on its way into the registers. In the &lt;a class="link" href="#tiled-mma" >thread layout diagram&lt;/a> from earlier, we see how each thread holds its data as pairs of halves, each pair packed into one 32-bit register. With &lt;code>LDSM_N&lt;/code>, each 128-bit row (8 halves) can be split directly into four 32-bit chunks for four threads, since the two fp16s a thread needs are already contiguous. With the transposed view, however, two contiguous fp16s no longer belong to the same thread. One half of a 32-bit chunk goes into one thread&amp;rsquo;s register, and the other half goes into a different thread&amp;rsquo;s register, to fit the MMA atom&amp;rsquo;s expectation:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/ldsm_reg.png"
width="2340"
height="1312"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/ldsm_reg_hu5c1a81e20100c08caf0f5fe4f1c1f005_185553_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/ldsm_reg_hu5c1a81e20100c08caf0f5fe4f1c1f005_185553_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/ldsm_reg_hu5c1a81e20100c08caf0f5fe4f1c1f005_185553_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/ldsm_reg.png 2340w"
loading="lazy"
alt="LDSM_N vs. LDSM_T SMEM-&amp;gt;register movement following the 16x8x16 MMA atom layout."
class="gallery-image"
data-flex-grow="178"
data-flex-basis="428px"
>
&lt;/p>
&lt;p>In SMEM for Q and K, the two fp16s each thread needs are adjacent in memory, so they end up in the same register. In SMEM for V, the data is transposed relative to what the MMA expects. Adjacent fp16s run along &lt;code>kHeadDim&lt;/code> (the MMA&amp;rsquo;s N dimension), while each B-operand register must hold two values that are adjacent along &lt;code>kBlockN&lt;/code> (the K dimension). &lt;code>LDSM_T&lt;/code> therefore has to send them to different threads&amp;rsquo; registers. We use the &lt;code>SM75_U16x8_LDSM_T&lt;/code> atom to load the V fragments. It describes its 128-bit source as 8x16 bits instead of 4x32 bits, so the atom knows that each 128-bit chunk is 8 individual fp16s that need to be transposed into the registers. Since SMEM for V is still stored contiguously in 8-half chunks, we can still vectorize across the entire 128 bits.&lt;/p>
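&lt;p>The transposed case can be sketched the same way, again in plain C++ with hypothetical names, based on my reading of the PTX ISA&amp;rsquo;s &lt;code>.trans&lt;/code> description. Take one stored 8x8 &lt;code>.b16&lt;/code> matrix whose rows are &lt;code>kBlockN&lt;/code> and whose columns are &lt;code>kHeadDim&lt;/code>. Each lane receives stored column &lt;code>lane / 4&lt;/code>, stored rows &lt;code>2 * (lane % 4)&lt;/code> and &lt;code>2 * (lane % 4) + 1&lt;/code>. That matches the m16n8k16 B-operand pairing (k = 2 * (lane % 4) + i, n = lane / 4). The sketch checks that neighbors in SMEM always land in different lanes.&lt;/p>

```cpp
// Plain C++ sketch of ldmatrix .trans on one stored 8x8 .b16 matrix
// (row-major, 8 halves per row). Offsets are row * 8 + col in the stored matrix.

// Lane receives stored column lane / 4, stored rows 2 * (lane % 4) + half.
constexpr int ldsm_t_elem(int lane, int half) {
  return (2 * (lane % 4) + half) * 8 + lane / 4;
}

// Inverse mapping: which lane ends up holding stored element elem.
constexpr int ldsm_t_owner(int elem) {
  return (elem % 8) * 4 + (elem / 8) / 2;  // col * 4 + row / 2
}

// The two halves packed into one register are a full row (8 elements) apart in SMEM.
constexpr bool register_pairs_are_column_neighbors() {
  for (int lane = 0; lane != 32; ++lane)
    if (ldsm_t_elem(lane, 1) - ldsm_t_elem(lane, 0) != 8) return false;
  return true;
}

// Every pair of SMEM-adjacent halves (same row, neighboring columns) goes to two different lanes.
constexpr bool adjacent_halves_split_across_lanes() {
  for (int row = 0; row != 8; ++row)
    for (int col = 0; col != 7; ++col)
      if (ldsm_t_owner(row * 8 + col) == ldsm_t_owner(row * 8 + col + 1)) return false;
  return true;
}

// ldsm_t_owner really inverts ldsm_t_elem.
constexpr bool owner_round_trips() {
  for (int lane = 0; lane != 32; ++lane)
    for (int half = 0; half != 2; ++half)
      if (ldsm_t_owner(ldsm_t_elem(lane, half)) != lane) return false;
  return true;
}
```

&lt;p>For example, stored elements 0 and 1 (row 0, columns 0 and 1) end up in lanes 0 and 4. That is the &amp;ldquo;one half goes to another thread&amp;rdquo; shuffle that &lt;code>LDSM_T&lt;/code> does for us in hardware.&lt;/p>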
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemCopyAtomTransposed&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">SM75_U16x8_LDSM_T&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="v-smem-partition">V SMEM Partition&lt;/h4>
&lt;p>The tiled MMA already handles the tiling strategy for our SMEM copy, but we have to transpose the &lt;code>sV&lt;/code> layout to match our new transposed atom and fit our tiling strategy from the &lt;a class="link" href="#v-smem-register" >visualization&lt;/a>. If we tiled &lt;code>sV&lt;/code> as-is, the atom tile and each thread tile would grab the wrong &amp;ldquo;next block&amp;rdquo; of SMEM:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/v_tile.png"
width="2342"
height="1246"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/v_tile_hu17bc38c7333ce895cd68441321bdfb02_211033_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/v_tile_hu17bc38c7333ce895cd68441321bdfb02_211033_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/v_tile_hu17bc38c7333ce895cd68441321bdfb02_211033_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/v_tile.png 2342w"
loading="lazy"
alt="Thread tiling strategy on sV as-is. It grabs the next 128 bits along kHeadDim instead of along kBlockN. Note the actual SMEM is swizzled."
class="gallery-image"
data-flex-grow="187"
data-flex-basis="451px"
>
&lt;/p>
&lt;p>Remember, we want our SMEM to &amp;ldquo;look like&amp;rdquo; &lt;code>(kHeadDim, kBlockN)&lt;/code>, so we can simply apply a layout transpose to reinterpret V SMEM&amp;rsquo;s underlying data. The easiest way to do this is to compose another layout on top of our swizzled &lt;code>SmemLayoutKV&lt;/code> from before. &lt;code>composition(A, B)&lt;/code> evaluates right to left, i.e. &lt;code>composition(A, B)(c) = A(B(c))&lt;/code>. So the layout that defines our new logical coordinates (the transposed view) goes on the right, and the layout that produces the physical, swizzled offsets goes on the left.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemLayoutVt&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">composition&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">SmemLayoutKV&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockN&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">)));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// FA2 source code uses GenRowMajor{} for stride
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// They&amp;#39;re not consistent with their stride defs haha
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
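&lt;/div>&lt;p>Our stride is row-major, &lt;code>(kBlockN, 1)&lt;/code>, because in this view we want stepping along the contraction dim &lt;code>kBlockN&lt;/code> to move the fastest. &lt;code>SmemLayoutKV&lt;/code> decomposes the 1D index it receives colexicographically over &lt;code>(kBlockN, kHeadDim)&lt;/code>. The index &lt;code>h * kBlockN + n&lt;/code> therefore maps back to the coordinate &lt;code>(n, h)&lt;/code>, which makes &lt;code>sVt(h, n)&lt;/code> the same element as &lt;code>sV(n, h)&lt;/code>. Our full layout is therefore:&lt;/p>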
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">composition([tile_to_shape(Swizzle-&amp;gt;Layout), (kBlockN, kHeadDim)],
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (kHeadDim, kBlockN))
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The reason we add a composition on top of &lt;code>SmemLayoutKV&lt;/code> instead of calling &lt;code>tile_to_shape&lt;/code> directly on our &lt;code>(kHeadDim, kBlockN)&lt;/code> shape is that our swizzle pattern is specifically configured around &lt;code>kHeadDim&lt;/code>. More importantly, &lt;code>sVt&lt;/code> has to address exactly the bytes that the GMEM-&amp;gt;SMEM copy wrote through &lt;code>sV&lt;/code>, so its physical mapping must agree with &lt;code>SmemLayoutKV&lt;/code> element for element. We could derive an equivalent swizzled layout for the transposed shape by hand, but it would require more swizzle tiling math. Instead, we can just apply the composition, which does all of the translation for us with no extra work. It&amp;rsquo;s the most efficient way to reuse what we&amp;rsquo;ve already derived, and it takes advantage of the fact that K and V have the same tile dimensions in memory.&lt;/p>
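&lt;p>To convince yourself the composition really is a transpose of the same physical bytes, here&amp;rsquo;s a plain-C++ sketch that evaluates both views by hand. It assumes &lt;code>kHeadDim = 32&lt;/code> and &lt;code>kBlockN = 64&lt;/code>. For that config, &lt;code>SmemLayoutKV&lt;/code> prints as &lt;code>Sw&amp;lt;2,3,3&amp;gt; o _0 o (_64,_32):(_32,_1)&lt;/code>, i.e. a &lt;code>Swizzle&amp;lt;2,3,3&amp;gt;&lt;/code> over a plain row-major layout. The swizzle is modeled as &amp;ldquo;XOR offset bits [6,8) into bits [3,5)&amp;rdquo;. All names are hypothetical, not CuTe API.&lt;/p>

```cpp
// Plain C++ sketch of sVt = composition(SmemLayoutKV, (kHeadDim, kBlockN):(kBlockN, 1))
// for kHeadDim = 32, kBlockN = 64, where SmemLayoutKV is Swizzle<2,3,3> o (64,32):(32,1).
constexpr int kHeadDim = 32;
constexpr int kBlockN = 64;

// Swizzle<2,3,3>: XOR the two offset bits starting at bit 6 into the two bits starting at bit 3.
constexpr int swizzle_233(int offset) {
  return offset ^ (((offset >> 6) % 4) * 8);
}

// SmemLayoutKV with a 2D coordinate: (n, h) -> physical element offset.
constexpr int smem_kv(int n, int h) { return swizzle_233(n * kHeadDim + h); }

// SmemLayoutKV fed a 1D index: decompose colexicographically over (kBlockN, kHeadDim).
constexpr int smem_kv_1d(int i) { return smem_kv(i % kBlockN, i / kBlockN); }

// Right-hand layout of the composition: (kHeadDim, kBlockN):(kBlockN, 1).
constexpr int vt_view(int h, int n) { return h * kBlockN + n; }

// composition(A, B)(c) = A(B(c)).
constexpr int smem_vt(int h, int n) { return smem_kv_1d(vt_view(h, n)); }

// sVt(h, n) must address exactly the element sV(n, h), for every coordinate.
constexpr bool vt_is_transpose_of_kv() {
  for (int h = 0; h != kHeadDim; ++h)
    for (int n = 0; n != kBlockN; ++n)
      if (smem_vt(h, n) != smem_kv(n, h)) return false;
  return true;
}
```

&lt;p>As a sanity check on the design choice, flip &lt;code>vt_view&lt;/code>&amp;rsquo;s stride to column-major &lt;code>(1, kHeadDim)&lt;/code> and the check fails. The composition would then merely reshape &lt;code>sV&lt;/code>&amp;rsquo;s storage order instead of transposing it.&lt;/p>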
&lt;p>Let&amp;rsquo;s apply this layout view to a new transposed &lt;code>sV&lt;/code> and finalize our copy:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sV&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sK&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sK&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">SmemLayoutKV&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// t for transposed
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sVt&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">SmemLayoutVt&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// tiled copy defs
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_tiled_copy_V&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tiled_copy_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SmemCopyAtomTransposed&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_thr_copy_V&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_tiled_copy_V&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// partition SMEM via our transposed SMEM
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// varname: thread, Output, sVt
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">tOsVt&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_V&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sVt&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>As you can see, we use &lt;code>sV&lt;/code> for our GMEM tiled copy since we preserve the row-major shape. Only when we copy from SMEM-&amp;gt;registers do we transpose the SMEM view for &lt;code>LDSM_T&lt;/code>.&lt;/p>
&lt;h3 id="svtnoswizzle-the-no-op-nobody-caught">&lt;code>sVtNoSwizzle&lt;/code>: The No-Op Nobody Caught&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/kernel_traits.cuh" target="_blank" rel="noopener"
>&lt;code>kernel_traits.cuh&lt;/code>&lt;/a>&lt;/p>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/v_fragment_test.cu" target="_blank" rel="noopener"
>&lt;code>scratch/v_fragment_test.cu&lt;/code>&lt;/a> (minimal repro: swap &lt;code>sVt&lt;/code> for &lt;code>sVtNoSwizzle&lt;/code>, see nothing break), &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/swizzle_layouts.cu" target="_blank" rel="noopener"
>&lt;code>scratch/swizzle_layouts.cu&lt;/code>&lt;/a> (print the layouts side by side)&lt;/p>
&lt;/blockquote>
&lt;blockquote>
&lt;p>&lt;strong>Tip&lt;/strong>: I recommend skipping this section if you&amp;rsquo;re trying to implement a working FA2 first. Come back when you have nothing left to lose.&lt;/p>
&lt;/blockquote>
&lt;p>Oh man, time to deal with the most frustrating line in the entire repo. Frustrating because it&amp;rsquo;s so simply declared, never explained, and leads you down multiple rabbit holes, only for you to realize it literally does nothing. I lost my sanity over this, and I&amp;rsquo;m convinced that the authors of FA2 did not fully understand this line either. Fortunately, my despair is now your enlightenment. Allow me to show you the way.&lt;/p>
&lt;p>If you look at the FA2 source code, you&amp;rsquo;ll see that it defines one last V SMEM view:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemLayoutVtNoSwizzle&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">get_nonswizzle_portion&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SmemLayoutVt&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sVtNoSwizzle&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">SmemLayoutVtNoSwizzle&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tOrV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">thr_mma&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_fragment_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sVtNoSwizzle&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This &lt;code>sVtNoSwizzle&lt;/code> is only used to derive the fragment shape for &lt;code>tOrV&lt;/code>. It leads us to assume that &lt;em>something&lt;/em> breaks if we don&amp;rsquo;t extract the &lt;code>nonswizzle_portion&lt;/code> of &lt;code>sVt&lt;/code>&amp;rsquo;s layout.&lt;/p>
&lt;p>When we declared our fragments for Q and K, we simply passed in the swizzled tensors &lt;code>sQ, sK&lt;/code> to &lt;code>partition_fragment&lt;/code>. So why don&amp;rsquo;t we do that for &lt;code>sVt&lt;/code>? You&amp;rsquo;d imagine it&amp;rsquo;s due to the transpose we apply on top of the swizzle atom. In this case you&amp;rsquo;d be right, with a few massive caveats.&lt;/p>
&lt;h4 id="breaking-the-fragment-shapes">Breaking the Fragment Shapes&lt;/h4>
&lt;p>If we compile the code for &lt;code>hdim=64,128&lt;/code> and pass in &lt;code>sVt&lt;/code> to &lt;code>partition_fragment_B&lt;/code>, all the layouts are identical to passing in &lt;code>sVtNoSwizzle&lt;/code>. Seemingly, it makes no difference. However, if we compile the code for &lt;code>hdim=32,96&lt;/code>, the fragment shapes end up different:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="o">==========&lt;/span> &lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span> &lt;span class="n">kBlockN&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">64&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span> &lt;span class="n">kSwizzle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">==========&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nl">SmemLayoutKV&lt;/span> &lt;span class="p">:&lt;/span> &lt;span class="n">Sw&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">o&lt;/span> &lt;span class="n">_0&lt;/span> &lt;span class="n">o&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">_64&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nl">SmemLayoutVt&lt;/span> &lt;span class="p">:&lt;/span> &lt;span class="n">Sw&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">o&lt;/span> &lt;span class="n">_0&lt;/span> &lt;span class="n">o&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_64&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nl">SmemLayoutVtNoSwizzle&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_64&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fragment_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sVtNoSwizzle&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_16&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fragment_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sVt&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),(&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),(&lt;/span>&lt;span class="n">_16&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_32&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// print non-swizzled sQ as well
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">fragment_A&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQNoSwizzle&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_16&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fragment_A&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">:&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_4&lt;/span>&lt;span class="p">),&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="n">_16&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>A couple of observations:&lt;/p>
&lt;ul>
&lt;li>The default &lt;code>print&lt;/code> function on the SMEM layouts prints a composed pseudo-layout (swizzle &lt;code>o&lt;/code> offset &lt;code>o&lt;/code> layout) that shows our different pieces. The reason is that a swizzle XORs bits of the offset, so there is no homogeneous shape/stride combo that can represent it. We&amp;rsquo;ll take a look at the layout visually in a bit.&lt;/li>
&lt;li>Something clearly &amp;ldquo;breaks&amp;rdquo; the fragments when we switch to non-multiples of 64. The second dimension of the V-fragment changes from a flat &lt;code>4&lt;/code> to a nested &lt;code>(2, 2)&lt;/code> layout. The Q-fragment&amp;rsquo;s outer strides flip as well. However, both shapes are still the same size in total.&lt;/li>
&lt;/ul>
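&lt;p>Both V fragments describe the same 64 register slots, just enumerated in a different order. One way to see this is to treat each printed layout as a function and check that it hits every slot exactly once. The sketch below is plain C++ with hypothetical names. It flattens the nested modes of the printouts by hand (leaf modes in order, first mode fastest) and does not use CuTe&amp;rsquo;s actual types.&lt;/p>

```cpp
// Plain C++ sketch: a flattened shape:stride layout (leaf modes only, first mode fastest).
struct FlatLayout {
  int rank;
  int shape[8];
  int stride[8];
};

constexpr int layout_size(FlatLayout l) {
  int total = 1;
  for (int i = 0; i != l.rank; ++i) total *= l.shape[i];
  return total;
}

// Map a 1D (colexicographic) index to an offset, the way a CuTe layout evaluates an integer.
constexpr int layout_offset(FlatLayout l, int idx) {
  int offset = 0;
  for (int i = 0; i != l.rank; ++i) {
    offset += (idx % l.shape[i]) * l.stride[i];
    idx /= l.shape[i];
  }
  return offset;
}

// True if the layout hits every register slot in [0, size) exactly once.
constexpr bool is_compact_permutation(FlatLayout l) {
  int n = layout_size(l);
  if (n > 256) return false;  // this sketch only tracks up to 256 slots
  int hits[256] = {};
  for (int i = 0; i != n; ++i) {
    int off = layout_offset(l, i);
    if (off >= n) return false;
    hits[off] += 1;
  }
  for (int i = 0; i != n; ++i)
    if (hits[i] != 1) return false;
  return true;
}

// The kHeadDim = 32 fragments printed above, flattened by hand.
constexpr FlatLayout kFragVNoSwizzle{4, {2, 2, 4, 4}, {1, 2, 4, 16}};        // ((2,2),4,4):((1,2),4,16)
constexpr FlatLayout kFragVSwizzled{5, {2, 2, 2, 2, 4}, {1, 2, 16, 32, 4}};  // ((2,2),(2,2),4):((1,2),(16,32),4)
constexpr FlatLayout kFragQNoSwizzle{5, {2, 2, 2, 2, 2}, {1, 2, 4, 16, 8}};  // ((2,2,2),2,2):((1,2,4),16,8)
constexpr FlatLayout kFragQSwizzled{5, {2, 2, 2, 2, 2}, {1, 2, 4, 8, 16}};   // ((2,2,2),2,2):((1,2,4),8,16)
```

&lt;p>All four layouts pass, so the &amp;ldquo;broken&amp;rdquo; fragments aren&amp;rsquo;t smaller or sparser. What differs is the mode grouping and which register slot a given coordinate maps to. For example, linear index 4 lands in slot 4 of the non-swizzled V fragment but in slot 16 of the swizzled one.&lt;/p>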
&lt;p>To understand what&amp;rsquo;s going on, we have to understand what &lt;code>partition_fragment&lt;/code> is actually doing. Canonically, we pass in tensors as its argument, e.g. &lt;code>tiled_mma.partition_fragment_A(sQ)&lt;/code>. However, this is a bit of a red herring, since the partitioning &lt;em>only needs the layout&lt;/em>. We already saw this with &lt;code>partition_fragment_C&lt;/code>, which is a standalone function that only takes in the tiled MMA and a shape. In the source code&lt;sup id="fnref:6">&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref">6&lt;/a>&lt;/sup>, the function only uses the tensor&amp;rsquo;s layout and nothing else. The tensor argument exists solely to anchor the partitioning to the physical SMEM buffer&amp;ndash;a readability convention, not a functional requirement.&lt;/p>
&lt;p>We can see how &lt;code>get_nonswizzle_portion&lt;/code> simply chops off the swizzle component on the left of the layout and leaves us with the raw non-swizzled layout. So using the non-swizzled layout is actually ideal. We only need to partition the SMEM shape to extract the correct fragment tiles. The swizzle only changes which SMEM addresses the copy reads from, and the copy&amp;rsquo;s source partition &lt;code>tOsVt&lt;/code> already accounts for that. The dumb conclusion is that &lt;strong>we should have passed in a non-swizzled &lt;code>sQ&lt;/code> and &lt;code>sK&lt;/code> to the partitioner as well.&lt;/strong>&lt;/p>
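&lt;p>Dropping the swizzle when sizing the fragment is safe because a swizzle only permutes offsets. It takes B bits starting at bit M + S and XORs them into the B bits starting at bit M, leaving the high bits alone. On any range that is a multiple of 2^(B+M+S), it is therefore a bijection that never changes the tile&amp;rsquo;s shape or size. Here&amp;rsquo;s a plain-C++ sketch of that property (hypothetical names, modeled on CuTe&amp;rsquo;s &lt;code>Swizzle&amp;lt;B, M, S&amp;gt;&lt;/code> convention, not its actual type). It checks the &lt;code>Swizzle&amp;lt;2,3,3&amp;gt;&lt;/code> and &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code> patterns used in this post.&lt;/p>

```cpp
// Plain C++ sketch of a CuTe-style Swizzle<B, M, S>: take the b bits starting at
// bit m + s and XOR them into the b bits starting at bit m.
constexpr int pow2(int k) {
  int p = 1;
  for (int i = 0; i != k; ++i) p *= 2;
  return p;
}

constexpr int swizzle(int offset, int b, int m, int s) {
  return offset ^ (((offset >> (m + s)) % pow2(b)) * pow2(m));
}

// True if the swizzle maps [0, n) onto itself, hitting every offset exactly once.
constexpr bool swizzle_is_permutation(int n, int b, int m, int s) {
  if (n > 8192) return false;  // this sketch only tracks up to 8192 offsets
  bool hit[8192] = {};
  for (int i = 0; i != n; ++i) {
    int j = swizzle(i, b, m, s);
    if (j >= n) return false;
    if (hit[j]) return false;
    hit[j] = true;
  }
  return true;
}
```

&lt;p>Because the source and destination bit ranges don&amp;rsquo;t overlap when S is at least B, applying the swizzle twice also returns the original offset. In other words, it is an involution.&lt;/p>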
&lt;p>However, the fact that the copy works regardless means there&amp;rsquo;s some deeper reason for why the fragment shapes stay consistent for &lt;code>hdim=64,128&lt;/code> and break for &lt;code>hdim=32,96&lt;/code>. If you&amp;rsquo;ve gotten to this point, you should just take the above conclusion and run. If you want to continue to lose your sanity, we&amp;rsquo;re going to understand why.&lt;/p>
&lt;h4 id="ok-lets-figure-out-why">Ok, Let&amp;rsquo;s Figure Out Why&lt;/h4>
&lt;p>The reason &lt;code>hdim=32&lt;/code> breaks our fragment shapes is due to how the source code handles swizzling for non-multiples of 64:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">kBlockKSmem&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">kHeadDim&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">kSwizzle&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemLayoutAtomQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">composition&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Swizzle&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kSwizzle&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>These are the same constants from &lt;a class="link" href="#kblocksmem" >&lt;code>kBlockKSmem&lt;/code>&lt;/a>. The hdim=32/96 path uses &lt;code>Swizzle&amp;lt;2,3,3&amp;gt;&lt;/code> instead of &lt;code>&amp;lt;3,3,3&amp;gt;&lt;/code> because each SMEM row is only 32 halfs wide: that is four 8-half (128-bit) columns, so only &lt;code>B=2&lt;/code> column bits are needed, and two consecutive SMEM rows share one 64-half bit-mask row. That is enough to clear conflicts without the wider permutation.&lt;/p>
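&lt;p>Notice, however, that our row and column bits S and B are no longer adjacent. For &lt;code>Swizzle&amp;lt;2,3,3&amp;gt;&lt;/code>, the column bits are bits 3&amp;ndash;4 and the row bits are bits 6&amp;ndash;7, so bit 5 sits between them untouched:&lt;/p>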
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle_233_333.png"
width="2154"
height="1208"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle_233_333_hu7849fed29fd6961eacfe8b04e97fb1a3_149018_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle_233_333_hu7849fed29fd6961eacfe8b04e97fb1a3_149018_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle_233_333_hu7849fed29fd6961eacfe8b04e97fb1a3_149018_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/swizzle_233_333.png 2154w"
loading="lazy"
alt="Swizzle bit masks for &amp;lt;3,3,3&amp;gt; and &amp;lt;2,3,3&amp;gt;."
class="gallery-image"
data-flex-grow="178"
data-flex-basis="427px"
>
&lt;/p>
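&lt;p>To make those bit positions concrete, here is a minimal constexpr sketch of the masks. It is our own toy model of &lt;code>Swizzle&amp;lt;B, M, S&amp;gt;&lt;/code> for positive &lt;code>S&lt;/code>, not CuTe&amp;rsquo;s implementation, and the helpers &lt;code>pow2&lt;/code>, &lt;code>row_mask&lt;/code>, &lt;code>col_mask&lt;/code>, and &lt;code>gap_mask&lt;/code> are hypothetical names: the &lt;code>B&lt;/code> row bits start at bit &lt;code>M+S&lt;/code>, the &lt;code>B&lt;/code> column bits start at bit &lt;code>M&lt;/code>, and anything in between is left alone.&lt;/p>

```cpp
// Toy model of the bit masks behind CuTe's Swizzle with parameters B, M, S (S > 0).
// Helper names are hypothetical; this is not CuTe's implementation.

// 2^n as a constant expression.
constexpr int pow2(int n) { return n == 0 ? 1 : 2 * pow2(n - 1); }

// "Row" bits: the B bits starting at bit M + S. The swizzle reads these.
constexpr int row_mask(int B, int M, int S) { return (pow2(B) - 1) * pow2(M + S); }

// "Column" bits: the B bits starting at bit M. These get XORed with the row bits.
constexpr int col_mask(int B, int M) { return (pow2(B) - 1) * pow2(M); }

// Bits strictly between the column bits and the row bits. The swizzle neither
// reads nor writes them (empty when S == B).
constexpr int gap_mask(int B, int M, int S) {
  return S > B ? pow2(M + S) - pow2(M + B) : 0;
}

// B=3, M=3, S=3: columns are bits 3-5, rows are bits 6-8, and they are adjacent.
static_assert(col_mask(3, 3) == 0x38, "bits 3-5");
static_assert(row_mask(3, 3, 3) == 0x1C0, "bits 6-8");
static_assert(gap_mask(3, 3, 3) == 0, "no gap");

// B=2, M=3, S=3: columns are bits 3-4, rows are bits 6-7, and bit 5 sits between.
static_assert(col_mask(2, 3) == 0x18, "bits 3-4");
static_assert(row_mask(2, 3, 3) == 0xC0, "bits 6-7");
static_assert(gap_mask(2, 3, 3) == 0x20, "bit 5 (value 32) is untouched");
```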
&lt;p>For the &lt;code>Swizzle&amp;lt;2, 3, 3&amp;gt;&lt;/code> pattern, two SMEM rows $2n, 2n+1$ map to the same bit-mask &amp;ldquo;row&amp;rdquo;, since the bank pattern repeats every 64 halfs (128 bytes). The swizzle leaves offset bit 5 (value 32) alone, so it only permutes the 8-half blocks within each true SMEM row. When we print out the resulting layouts for &lt;code>sK&lt;/code> and &lt;code>sVt&lt;/code>, we begin to understand:&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/sk_svt_layout.png"
width="2330"
height="1026"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/sk_svt_layout_hu9cb3d163ccec2f4856d73fbd4fd288ea_337383_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/sk_svt_layout_hu9cb3d163ccec2f4856d73fbd4fd288ea_337383_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/sk_svt_layout_hu9cb3d163ccec2f4856d73fbd4fd288ea_337383_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/sk_svt_layout.png 2330w"
loading="lazy"
alt="Layout of sK and sVt, kHeadDim=32, Swizzle&amp;lt;2,3,3&amp;gt;."
class="gallery-image"
data-flex-grow="227"
data-flex-basis="545px"
>
&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Tip:&lt;/strong> You can print any layout as LaTeX with &lt;code>print_latex(layout)&lt;/code>. I rendered this in &lt;a class="link" href="https://overleaf.com" target="_blank" rel="noopener"
>Overleaf&lt;/a>.&lt;/p>
&lt;/blockquote>
&lt;p>True to its definition, &lt;code>sVt&lt;/code>&amp;rsquo;s layout is a pure transposition of &lt;code>sK&lt;/code>. Notice something quite peculiar: going down column 0 of &lt;code>sK&lt;/code>, the offset stride alternates between 32 and 40. If we think about our swizzle pattern for a second, this makes sense. Let&amp;rsquo;s look at which offsets end up in column 0:&lt;/p>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th style="text-align:left">SMEM Rows&lt;/th>
&lt;th style="text-align:left">Shared Bit-Mask Row&lt;/th>
&lt;th style="text-align:left">Physical Column (XORs to 0)&lt;/th>
&lt;th style="text-align:left">Actual Addresses&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td style="text-align:left">Row 0, 1&lt;/td>
&lt;td style="text-align:left">Row 0&lt;/td>
&lt;td style="text-align:left">Column 0&lt;/td>
&lt;td style="text-align:left">&lt;code>0=0b0&lt;/code>, &lt;code>32=0b100000&lt;/code>&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td style="text-align:left">Row 2, 3&lt;/td>
&lt;td style="text-align:left">Row 1&lt;/td>
&lt;td style="text-align:left">Column 1&lt;/td>
&lt;td style="text-align:left">&lt;code>0b1001000=72&lt;/code>, &lt;code>0b1101000=104&lt;/code>&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
&lt;p>Since each even-odd SMEM row pair shares a bit-mask row, the only offset bit that differs within a pair is the untouched bit 5, so the difference is 32. Moving to the next pair bumps the bit-mask row by 1, and since the only way $a \oplus b = 0$ is when $a=b$, the column bits must also go up by 1. Going from an odd row to the next even row therefore adds 64 (next bit-mask row) and 8 (next column) and removes 32 (bit 5 clears), for a net stride of &lt;code>0b101000&lt;/code>, which is 40.&lt;/p>
&lt;p>For &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code>, no SMEM rows share a bit-mask row. Therefore, an offset in the same column as the row before it must increase by &lt;code>0b1001000&lt;/code> (64 for the next row plus 8 for the next column), which is 72. Going down column 0 of &lt;code>sK&lt;/code> (or along row 0 of &lt;code>sVt&lt;/code>), each offset is a constant 72 past the previous one. Since hdims that are multiples of 64 use &lt;code>kBlockKSmem=64&lt;/code>, we get a fully unique set of row and column bits with a constant 72 stride. This is not true for &lt;code>hdim=32,96&lt;/code>. This constant vs. alternating stride difference hints that &lt;code>partition_fragment&lt;/code> can&amp;rsquo;t extract a simple shape/stride pattern from the hdim=32 swizzled layout.&lt;/p>
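&lt;p>We can sanity-check both stride patterns numerically. The sketch below is a toy model with hypothetical helper names (&lt;code>swizzle&lt;/code>, &lt;code>smem_offset&lt;/code>, &lt;code>column_stride&lt;/code>), not CuTe code: it computes the row-major offset of logical element &lt;code>(row, col)&lt;/code> in an atom &lt;code>width&lt;/code> halfs wide, applies a &lt;code>Swizzle&amp;lt;B, 3, 3&amp;gt;&lt;/code>-style XOR, and measures the step going down logical column 0.&lt;/p>

```cpp
// Toy model of a swizzled, row-major SMEM atom, counted in halfs.
// Helper names are hypothetical; this is not CuTe code.

constexpr int pow2(int n) { return n == 0 ? 1 : 2 * pow2(n - 1); }

// XOR the B bits at [M+S, M+S+B) onto the B bits at [M, M+B) (S > 0 only).
constexpr int swizzle(int offset, int B, int M, int S) {
  return offset ^ (((offset >> (M + S)) % pow2(B)) * pow2(M));
}

// Physical offset of logical element (row, col) in a row-major atom `width` halfs
// wide, swizzled with B mask bits, M = 3 and S = 3.
constexpr int smem_offset(int row, int col, int width, int B) {
  return swizzle(row * width + col, B, 3, 3);
}

// Offset step from row `row` to row `row + 1` within the same logical column.
constexpr int column_stride(int row, int col, int width, int B) {
  return smem_offset(row + 1, col, width, B) - smem_offset(row, col, width, B);
}

// hdim=32 path (B=2, 32-wide rows): column 0 holds 0, 32, 72, 104, ...
static_assert(smem_offset(2, 0, 32, 2) == 72, "row 2 lands in physical column 1");
static_assert(smem_offset(3, 0, 32, 2) == 104, "row 3 shares row 2's bit-mask row");
static_assert(column_stride(0, 0, 32, 2) == 32, "same bit-mask row: only bit 5 flips");
static_assert(column_stride(1, 0, 32, 2) == 40, "next bit-mask row: +64 +8 -32");

// hdim=64/128 path (B=3, 64-wide rows): every step down column 0 is 72.
static_assert(column_stride(0, 0, 64, 3) == 72, "+64 +8");
static_assert(column_stride(6, 0, 64, 3) == 72, "+64 +8");
```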
&lt;p>When I figured this out, I didn&amp;rsquo;t bother looking into how CuTe actually computes the output fragment. The culprit is presumably some inability to extract a flat layout, whatever the exact reason. Furthermore, even though FA2 strips the swizzle from &lt;code>sVt&lt;/code>, it still uses a botched-up version of &lt;code>sQ&lt;/code>&amp;hellip;but the algorithm still works. So I wondered: what if I just replaced &lt;code>sVtNoSwizzle&lt;/code> with &lt;code>sVt&lt;/code>, even for &lt;code>hdim=32&lt;/code>? Something must break, right?&lt;/p>
&lt;h4 id="bruh">Bruh&lt;/h4>
&lt;p>Yeah, nope. I tested the kernel and it worked perfectly&amp;ndash;exact same output as the non-swizzled version, down to the last decimal. At this point I had spent 6 or 7 hours trying to figure out why this stupid line was there, only to realize it never mattered anyway. My guess is that someone copied snippets from a CuTe example or another kernel, or got scared when the debug prints of the layouts looked wrong.&lt;/p>
&lt;p>This still raises the question: how does the kernel work despite these wonky layouts? The hint lies in the CuTe source code.&lt;/p>
&lt;h4 id="cute-source-code">CuTe Source Code&lt;/h4>
&lt;p>If we do a bit of digging into &lt;code>partition_fragment&lt;/code>&amp;rsquo;s source code&lt;sup id="fnref1:6">&lt;a href="#fn:6" class="footnote-ref" role="doc-noteref">6&lt;/a>&lt;/sup>, we find the call stack eventually calls &lt;code>make_fragment_like&lt;/code>, which has this cute little comment next to it&lt;sup id="fnref:7">&lt;a href="#fn:7" class="footnote-ref" role="doc-noteref">7&lt;/a>&lt;/sup>:&lt;/p>
&lt;blockquote>
&lt;p>&lt;code>make_fragment_like&lt;/code>: Make a tensor the same shape and (if possible) order as another tensor, with special
consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms
so this allocates the 0th mode with LayoutLeft regardless of the reference layout.&lt;/p>
&lt;/blockquote>
&lt;p>This is the infuriating aha moment. The fragment tries to mirror the source layout&amp;rsquo;s shape and mode order, but always allocates the 0th mode with a compact column-major (&lt;code>LayoutLeft&lt;/code>) stride. If we look at our &lt;a class="link" href="#breaking-the-fragment-shapes" >fragment shape prints from earlier&lt;/a>, we see that the &amp;ldquo;correct&amp;rdquo; and botched fragments have the same 0th mode shape with the same default column-major strides. If you recall, this mode in the tiled MMA represents the value layout within a single MMA tile. Slowly, we realize this is the only register indexing that matters. The outer tile modes do differ, but they only describe which registers each tile occupies, not anything in physical memory. Since every copy into the fragment and every MMA read from it goes through the same fragment layout, the copy and MMA ops stay consistent with each other. Each tile just lives in different registers, but is always referenced consistently. Since the value layout within each tile is necessarily correct, the kernel is right just the same.&lt;/p>
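&lt;p>A toy model makes the argument concrete. The sketch below is not CuTe code, and every name in it is hypothetical: a fragment holds &lt;code>V&lt;/code> values per MMA tile (mode 0, packed contiguously in both mappings, like the &lt;code>LayoutLeft&lt;/code> 0th mode) and &lt;code>T&lt;/code> tiles in its outer modes. The two mappings place the tiles in different registers, but because the fake copy that fills the registers and the fake MMA that reads them both go through the same mapping, the result is identical.&lt;/p>

```cpp
// Toy model (not CuTe): a register fragment with V values per MMA tile and T tiles.
// Mode 0 (the values inside one tile) is packed contiguously in both mappings,
// mirroring the LayoutLeft 0th mode; only the placement of whole tiles differs.
constexpr int V = 4;  // values per tile (mode 0)
constexpr int T = 6;  // number of tiles (outer modes)

// Mapping 1: tiles stored in order.
constexpr int reg_in_order(int v, int t) { return v + V * t; }

// Mapping 2: tiles shuffled; t -> 5t mod 6 is a bijection because gcd(5, 6) == 1.
constexpr int reg_shuffled(int v, int t) { return v + V * ((5 * t) % T); }

// Fake "copy": element (v, t) goes into register reg(v, t).
// Fake "MMA": read element (v, t) back from reg(v, t) and accumulate it with
// position-dependent weights, so mixing up elements would generally change the sum.
constexpr long run(int (*reg)(int, int)) {
  int regs[V * T] = {};
  for (int t = 0; t != T; ++t)
    for (int v = 0; v != V; ++v)
      regs[reg(v, t)] = 10 * t + v;  // stand-in for a value loaded from SMEM

  long acc = 0;
  for (int t = 0; t != T; ++t)
    for (int v = 0; v != V; ++v)
      acc += regs[reg(v, t)] * (t + 1) * (v + 2);
  return acc;
}

// Same result: the tiles sit in different registers but are referenced consistently.
static_assert(run(reg_in_order) == run(reg_shuffled), "tile placement is invisible");
```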
&lt;h4 id="a-hilariously-simple-fix">A Hilariously Simple Fix&lt;/h4>
&lt;p>Remember when we simply let FA2 slide with this swizzle declaration?&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">kSwizzle&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">...&lt;/span>&lt;span class="n">Swizzle&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kSwizzle&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}...;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Yeah, me neither. For &lt;code>kBlockKSmem=32&lt;/code>, we said that &lt;code>B=2&lt;/code> because there are only four 8-half columns per row (2 bits). But does it have to be? Physical SMEM is carved out of the same on-chip storage as L1, and it is just a flat array of banks with no concept of a layout&amp;ndash;that&amp;rsquo;s purely software. We can simply treat each even-odd row pair as one 64-half row by setting &lt;code>kSwizzle=3&lt;/code>. Our &lt;code>B&lt;/code> bitmask then treats the underlying SMEM as if it were 64 wide again. Just like before, each pseudo-row is conflict-free and the tiled layout still extends 8 rows deep for our thread copy. I scratched my head for a while, thinking that this would fix all the shape problems from earlier but break some access pattern somewhere else. Like a good engineer, I simply changed &lt;code>kSwizzle=3&lt;/code> in my code and tested it&amp;hellip;and&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/tests_pass.png"
width="243"
height="208"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/tests_pass.png 243w"
loading="lazy"
alt="Testing passed."
class="gallery-image"
data-flex-grow="116"
data-flex-basis="280px"
>
&lt;/p>
&lt;p>So it seems that a variable &lt;code>kSwizzle&lt;/code> pattern isn&amp;rsquo;t necessary, even for smaller hdims. We don&amp;rsquo;t need to replicate our physical SMEM layout&amp;ndash;we just need the swizzle to do its job, and it does its job just the same with a constant &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code>. Honestly, we found a simplification after clawing at our faces for hours on end&amp;ndash;a one-line change. Worth it.&lt;/p>
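&lt;p>Passing tests show the constant swizzle is correct, but not that it is still conflict-free. Here is a rough arithmetic check under a simplified bank model (our assumptions, not FA2 code; &lt;code>swizzle&lt;/code>, &lt;code>conflict_free&lt;/code>, and &lt;code>whole_tile_conflict_free&lt;/code> are hypothetical names): offsets are in halfs from a 128-byte-aligned base, a 128-byte bank row holds eight 16-byte chunks, and an 8-row read of one 16-byte column chunk, roughly one &lt;code>ldmatrix&lt;/code> phase, is conflict-free when the eight rows land in eight different chunk slots.&lt;/p>

```cpp
// Simplified bank model (assumptions, not FA2 code): fp16 elements, offsets in
// halfs from a 128-byte-aligned base, 8 halfs = one 16-byte chunk, and one
// 128-byte bank row = 8 chunk slots. An 8-row read of one column chunk (roughly
// one ldmatrix phase) is conflict-free when the 8 chunks use 8 different slots.

constexpr int pow2(int n) { return n == 0 ? 1 : 2 * pow2(n - 1); }

// Swizzle with B mask bits, M = 3, S = 3: bits [6, 6+B) are XORed onto bits [3, 3+B).
constexpr int swizzle(int offset, int B) {
  return offset ^ (((offset >> 6) % pow2(B)) * 8);
}

// Rows row0 .. row0+7 at column `col` of a row-major tile `width` halfs wide.
constexpr bool conflict_free(int B, int width, int row0, int col) {
  int seen = 0;  // bit k is set once chunk slot k has been used
  for (int i = 0; i != 8; ++i) {
    int slot = (swizzle((row0 + i) * width + col, B) / 8) % 8;
    if ((seen / pow2(slot)) % 2 == 1) return false;  // slot already used: conflict
    seen += pow2(slot);
  }
  return true;
}

// Every aligned 8-row group of a 64-row tile, at every 8-half column chunk.
constexpr bool whole_tile_conflict_free(int B, int width) {
  for (int row0 = 0; row0 != 64; row0 += 8)
    for (int col = 0; col != width; col += 8)
      if (!conflict_free(B, width, row0, col)) return false;
  return true;
}

static_assert(!whole_tile_conflict_free(0, 32), "no swizzle (B=0): 32-wide rows collide");
static_assert(whole_tile_conflict_free(2, 32), "B=2 on 32-wide rows (original hdim=32 choice)");
static_assert(whole_tile_conflict_free(3, 32), "B=3 on 32-wide rows (the one-line change)");
static_assert(whole_tile_conflict_free(3, 64), "B=3 on 64-wide rows");
```

&lt;p>Under this model, both the original &lt;code>Swizzle&amp;lt;2,3,3&amp;gt;&lt;/code> and the constant &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code> keep 32-wide rows conflict-free, while the unswizzled layout does not. For a real kernel, the final word is still what &lt;code>ncu&lt;/code> reports for shared-memory bank conflicts.&lt;/p>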
&lt;h4 id="there-are-multiple-options">There are Multiple Options&lt;/h4>
&lt;p>There&amp;rsquo;s technically nothing wrong with the code as-is, but we have a few options to reduce the confusion that &lt;code>sVtNoSwizzle&lt;/code> introduces:&lt;/p>
&lt;ol>
&lt;li>Add no-swizzle versions for Q and K as well. This makes it clearest that &lt;code>partition_fragment&lt;/code> only cares about the unswizzled layout. It might read as a requirement, which it is not, but it states the intent most directly.&lt;/li>
&lt;li>Remove &lt;code>sVtNoSwizzle&lt;/code>. It&amp;rsquo;s not strictly necessary, and it implies assumptions that we now know don&amp;rsquo;t hold. I assume it was added because a debug print showed that the V fragment had a strange shape; removing it brings that strange shape back, which may confuse developers.&lt;/li>
&lt;li>Change the swizzle pattern to &lt;code>Swizzle&amp;lt;3,3,3&amp;gt;&lt;/code> for all relevant hdims. Paired with option 2, this removes every shape inconsistency. The swizzle pattern no longer mirrors the physical SMEM row width at &lt;code>hdim=32,96&lt;/code>, but to be fair, neither does using &lt;code>kBlockKSmem=64&lt;/code> for &lt;code>hdim=128&lt;/code>.&lt;/li>
&lt;/ol>
&lt;p>The choice is ultimately up to you. I like option 3 paired with option 2, since it&amp;rsquo;s the simplest and leaves no shape inconsistencies, so that is what I will use below. Let&amp;rsquo;s refactor the code snippet we introduced in this subsection:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">static&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">kBlockKSmem&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">kHeadDim&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="mi">64&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemLayoutAtomQ&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">composition&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Swizzle&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Stride&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockKSmem&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sVt&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sV&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">SmemLayoutVt&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tOrV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">thr_mma&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_fragment_B&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sVt&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;blockquote>
&lt;p>Like the Q and K SMEM copy, the V SMEM copy is done inside its GEMM loop. We&amp;rsquo;ll cover this &lt;a class="link" href="#putting-it-all-together" >after we cover softmax&lt;/a>, since we have to compute &lt;code>S&lt;/code> first.&lt;/p>
&lt;/blockquote>
&lt;h2 id="the-actual-async-copy-strategy">The Actual Async Copy Strategy&lt;/h2>
&lt;p>At this point, we&amp;rsquo;ve more than covered &lt;em>how&lt;/em> to copy. Now, let&amp;rsquo;s explain &lt;em>when&lt;/em> to copy.&lt;/p>
&lt;p>We covered much of the strategy all the way back at the &lt;a class="link" href="#the-kernel-outline" >beginning&lt;/a>, so feel free to take a moment to review.&lt;/p>
&lt;h3 id="lets-assume-the-fa2-people-did-the-work">Let&amp;rsquo;s Assume the FA2 People Did the Work&lt;/h3>
&lt;p>There are a dozen ways to schedule the async loads: K tile here or there, V tile pipelined two-deep, fences merged or split. We have some heuristics that tell us when to do what, but every decision had to be empirically proven at some point. You might ask why they chose &lt;em>this&lt;/em> specifically, and the honest answer is: I&amp;rsquo;m not sure. Either the choice isn&amp;rsquo;t that important to begin with, or there was some specific advantage to doing it that way. We will assume Dao and his team did the work instead of questioning every decision. If you want to chase the rest down with &lt;code>ncu&lt;/code>, be my guest.&lt;/p>
&lt;h3 id="first-up-q-tile-and-k-tile-prefetch">First Up: Q-Tile and K-Tile Prefetch&lt;/h3>
&lt;p>The Q tile remains the same throughout the entire thread block, so it&amp;rsquo;s arguably the one that benefits the least from the async copy. We can still overlap a small amount of compute before our main loop. At the same time, we can fetch our 0th K-tile to prepare for the immediate $QK^T$ once we hit the main loop.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tQgQ&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tQsQ&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// issue first K copy tile &amp;#34;0&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tKgK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}),&lt;/span> &lt;span class="n">tKsK&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_fence&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The Q and K copies are issued immediately in the background, and we add a &lt;code>cute::cp_async_fence()&lt;/code> to commit a group that bundles every async copy issued since the last fence (here, the Q tile and K tile 0). The fence is not a barrier: it only marks a group that we can later wait on. It is used in tandem with &lt;code>cute::cp_async_wait&amp;lt;N&amp;gt;()&lt;/code>, which blocks the current thread until at most N of its most recently committed groups are still outstanding (e.g., &lt;code>cp_async_wait&amp;lt;0&amp;gt;()&lt;/code> waits until all groups are completely finished).&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: &lt;code>cp_async_wait&amp;lt;N&amp;gt;&lt;/code> only guarantees that the calling thread&amp;rsquo;s own copies have landed; it does not synchronize threads or cover copies issued by other threads. Most of the time this means you have to manually call &lt;code>__syncthreads()&lt;/code>, since the threads that load the data are usually not the only threads that end up reading it.&lt;/p>
&lt;/blockquote>
&lt;p>This fence-wait pattern is extremely common for software pipelining. For example, we could keep the next 10 K blocks in flight if we had enough compute to overlap the loads. Production GEMM kernels commonly keep two or three blocks in flight. With two blocks in flight, you might call &lt;code>wait&amp;lt;1&amp;gt;&lt;/code> at each iteration: it waits for the oldest outstanding block while leaving the most recently issued one in flight, so you can run multiple loop iterations without stalling on every load.&lt;/p>
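&lt;p>Here is a small bookkeeping model of the fence/wait semantics. A wait with count &lt;code>N&lt;/code> guarantees that every group except the &lt;code>N&lt;/code> most recently committed ones has landed, so the model simply treats groups as completing in commit order. It is a host-side toy, not the CuTe API: &lt;code>AsyncGroups&lt;/code> and &lt;code>tile_ready_every_iteration&lt;/code> are hypothetical names. It checks that a pipeline keeping &lt;code>stages - 1&lt;/code> tiles in flight can safely consume tile &lt;code>i&lt;/code> after &lt;code>cp_async_wait&amp;lt;stages - 2&amp;gt;&lt;/code>, and that waiting with a larger count is not enough.&lt;/p>

```cpp
// Host-side toy model of cp.async commit-group bookkeeping (not the CuTe API).
// A wait with count n guarantees every group except the n most recently
// committed ones has landed, so groups are modeled as completing in commit order.
struct AsyncGroups {
  int committed = 0;  // groups closed so far (one per cp_async_fence)
  int completed = 0;  // groups guaranteed to have landed in SMEM

  constexpr void fence() { ++committed; }

  // Models cp_async_wait with count n: afterwards at most n groups are in flight.
  constexpr void wait(int n) {
    if (committed - completed > n) completed = committed - n;
  }
};

// Tile i is copied in commit group i. The prologue issues tiles 0 .. stages-2;
// each iteration waits with count wait_n, consumes tile i, then issues tile
// i + stages - 1 (an empty group near the end keeps the counting uniform).
constexpr bool tile_ready_every_iteration(int stages, int wait_n, int num_tiles) {
  AsyncGroups g{};
  for (int s = 0; s != stages - 1; ++s) g.fence();
  for (int i = 0; i != num_tiles; ++i) {
    g.wait(wait_n);
    if (i >= g.completed) return false;  // tile i may still be in flight
    g.fence();                           // copies for tile i + stages - 1
  }
  return true;
}

static_assert(tile_ready_every_iteration(3, 1, 16), "two tiles ahead: wait with count 1");
static_assert(!tile_ready_every_iteration(3, 2, 16), "count 2 never waits for tile i");
static_assert(tile_ready_every_iteration(2, 0, 16), "one tile ahead: wait with count 0");
```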
&lt;p>FA2 only uses a one-block prefetch: at each iteration, we only prefetch the immediate next block. The three main reasons are likely register pressure, SMEM limits, and tile sizing. FA2 gets quite close to the per-thread register limit, since each thread has to hold its fragments of Q, K, and V, the accumulator fragment, softmax statistics, and the bookkeeping for giant unrolled loops. Tracking more in-flight copies and running heavier loops would push it even closer to the register ceiling. Furthermore, each extra prefetch stage means another tile buffer in SMEM. Our tiles are up to $128\times 128$ halfs (32 KiB each), which already have a huge SMEM footprint. Doubling or tripling these buffers would likely crush occupancy.&lt;/p>
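&lt;p>To put rough numbers on the SMEM argument, here is a back-of-the-envelope sketch. The helper &lt;code>smem_bytes&lt;/code> is hypothetical, and it assumes 2-byte (fp16/bf16) elements, one Q tile, and &lt;code>kv_stages&lt;/code> buffers each for K and V, ignoring padding and any other SMEM the kernel uses.&lt;/p>

```cpp
// Back-of-the-envelope SMEM footprint (hypothetical helper). Assumes 2-byte
// elements, one Q tile, and kv_stages buffers each for K and V, with no padding.
constexpr long smem_bytes(long block_m, long block_n, long head_dim, long kv_stages) {
  const long bytes_per_elem = 2;                             // fp16 / bf16
  const long q_elems = block_m * head_dim;                   // one Q tile
  const long kv_elems = kv_stages * 2 * block_n * head_dim;  // K and V buffers
  return bytes_per_elem * (q_elems + kv_elems);
}

// 128 x 128 tiles (head_dim 128): each tile is 32 KiB.
static_assert(smem_bytes(128, 128, 128, 1) == 96 * 1024, "single-buffered K/V: 96 KiB");
static_assert(smem_bytes(128, 128, 128, 2) == 160 * 1024, "double-buffered K/V: 160 KiB");
static_assert(smem_bytes(128, 128, 128, 3) == 224 * 1024, "triple-buffered K/V: 224 KiB");
```

&lt;p>For reference, the CUDA programming guide lists at most 164 KB of shared memory per SM for compute capability 8.0 (A100), so at $128\times 128$ tiles a single block already takes most of an SM, and triple-buffered K/V would not fit at all.&lt;/p>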
&lt;h3 id="main-loop">Main Loop&lt;/h3>
&lt;p>In our main loop, we prefetch K and V every iteration to keep the next blocks coming in while we do our MMAs and softmax. We issue the K prefetch during softmax and the V GEMM, and the V prefetch during the $QK^T$ GEMM, to maximize our overlap. At the start of each iteration, we wait on our K tile. Once it&amp;rsquo;s loaded, we issue our V prefetch and immediately begin the $QK^T$ MMA so the V copy overlaps this compute period. Then we wait on the V tile. Once it&amp;rsquo;s loaded, we issue the next K copy and move on to softmax and the V GEMM.&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/main_loop.png"
width="2360"
height="1312"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/main_loop_hu2343f0e7a0324f5ede8f0b989a56f96a_217582_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/main_loop_hu2343f0e7a0324f5ede8f0b989a56f96a_217582_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/main_loop_hu2343f0e7a0324f5ede8f0b989a56f96a_217582_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/main_loop.png 2360w"
loading="lazy"
alt="Main Loop Flow"
class="gallery-image"
data-flex-grow="179"
data-flex-basis="431px"
>
&lt;/p>
&lt;p>FA2 uses the simplest fence-wait strategy possible: issue one copy group, and fully wait for it to land before issuing the next one. This keeps the logic simple, since we always know which block is ready, and we&amp;rsquo;ll assume it lets the kernel hand-pick exactly which compute each load overlaps.&lt;/p>
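&lt;p>A quick bookkeeping trace of that schedule shows why a wait count of 0 is enough: at every &lt;code>cp_async_wait&amp;lt;0&amp;gt;()&lt;/code>, exactly one commit group is in flight, so we always know which tile just landed. This is a toy with a hypothetical function name that counts commit groups only; it assumes the Q and K-tile-0 copies share one group, as in the prologue above, and that the next K copy is skipped on the last iteration.&lt;/p>

```cpp
// Toy trace of the K/V schedule described above (hypothetical name; counts
// commit groups only). Returns the most groups ever in flight at a wait.
constexpr int max_in_flight_at_wait(int n_blocks) {
  int committed = 1;  // prologue: Q tile + K tile 0, one fence
  int completed = 0;
  int worst = 0;
  for (int nblock = 0; nblock != n_blocks; ++nblock) {
    // wait on K tile nblock (count 0 drains everything)
    if (committed - completed > worst) worst = committed - completed;
    completed = committed;
    ++committed;  // issue V tile nblock + fence; overlaps the QK^T GEMM

    // wait on V tile nblock
    if (committed - completed > worst) worst = committed - completed;
    completed = committed;
    if (nblock + 1 != n_blocks) ++committed;  // issue K tile nblock+1 + fence
  }
  return worst;
}

static_assert(max_in_flight_at_wait(1) == 1, "");
static_assert(max_in_flight_at_wait(8) == 1, "always exactly one group to wait on");
```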
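&lt;p>The main loop iterates over all K and V tiles, which are blocked along the N dimension (see &lt;a class="link" href="#v-smem-register" >the SMEM tiled layout&lt;/a>). We can simply write a for loop over the &lt;code>ceil_div(seqlen_k, kBlockN)&lt;/code> blocks. (The &lt;code>#pragma unroll&lt;/code> below is effectively a no-op: nvcc only fully unrolls a loop when its trip count is a compile-time constant, and &lt;code>nBlocksN&lt;/code> is not.) We can begin to fill in our strategy based on what we described above:&lt;/p>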
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">nBlocksN&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">ceil_div&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">seqlen_k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kBlockN&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">nblock&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">nblock&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">nBlocksN&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">nblock&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// FA2 actually defines acc_s here instead of outside the loop
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// It&amp;#39;s more a code signal that acc_s belongs to the loop
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// It&amp;#39;s not actually allocating anything each iteration
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Tensor&lt;/span> &lt;span class="n">acc_s&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">partition_fragment_C&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockN&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">clear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// wait on K
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_wait&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// We NEED this: cp_async_wait only waits on this thread&amp;#39;s own copies,
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// so the barrier makes every thread&amp;#39;s K writes visible before the GEMM
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">__syncthreads&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// issue V copy
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tVgV&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">nblock&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tVsV&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_fence&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// do QK gemm, pseudocode
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">gemm_QK&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// wait on V
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_wait&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">__syncthreads&lt;/span>&lt;span class="p">();&lt;/span> &lt;span class="c1">// V is visible, and every thread is done reading sK&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// next K block prefetch
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">nblock&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">nBlocksN&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span> &lt;span class="c1">// not last block
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tKgK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">nblock&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tKsK&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_fence&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// softmax rescale + SV Gemm
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The Q tile and K tile 0 loads share one fence (one commit group), and every subsequent K or V load gets its own fence. Note that each iteration issues the current iteration&amp;rsquo;s V tile and the &lt;em>next&lt;/em> K tile, so the next K copy is in flight while the current softmax and $SV$ work is going on. As before, you might wonder why we issue the K copy before the softmax instead of after it. With a single K buffer, it can&amp;rsquo;t go any earlier than the second &lt;code>__syncthreads()&lt;/code>, since every thread must finish reading &lt;code>sK&lt;/code> in the QK GEMM before the next tile overwrites it; beyond that, I assume this location maximizes overlap before K is needed again, but I&amp;rsquo;m not sure. I guess, &lt;a class="link" href="#lets-assume-the-fa2-people-did-the-work" >let&amp;rsquo;s assume they did the work&lt;/a>. You are welcome to move things around, then test and benchmark to verify.&lt;/p>
&lt;p>We&amp;rsquo;ve finally arrived at the end of our Q, K, V copy journey, and you can see how much work and understanding these seemingly simple copies required. Next, we&amp;rsquo;ll cover the online softmax and the $SV$ GEMM and fill them into the main loop. Then our final step will be more copying &amp;ndash; but this time storing the output from our fragments back into GMEM.&lt;/p>
&lt;h1 id="online-softmax">Online Softmax&lt;/h1>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/softmax.cuh" target="_blank" rel="noopener"
>&lt;code>softmax.cuh&lt;/code>&lt;/a>&lt;/p>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/test_softmax.cu" target="_blank" rel="noopener"
>&lt;code>scratch/test_softmax.cu&lt;/code>&lt;/a> (end-to-end &lt;code>softmax_rescale_o&lt;/code> test against a CPU reference)&lt;/p>
&lt;/blockquote>
&lt;p>After $QK^T$, we now deal with the online softmax. Fortunately, this step isn&amp;rsquo;t terribly difficult because of the way we set up the threads. To review, the softmax portion has five steps:&lt;/p>
&lt;ol>
&lt;li>Calculate new per-row max: &lt;code>m_new = max(scores, dim=-1)&lt;/code>&lt;/li>
&lt;li>Apply max rescale and exponentiation on scores: &lt;code>scores_exp = exp(scores - m_new)&lt;/code>&lt;/li>
&lt;li>Calculate correction factor: &lt;code>correction = exp(m_old - m_new)&lt;/code>&lt;/li>
&lt;li>Apply correction to output/accumulator: &lt;code>acc *= correction&lt;/code>&lt;/li>
&lt;li>Apply correction and compute new expsum denominator: &lt;code>l = l*correction + scores_exp.sum(dim=-1)&lt;/code>&lt;/li>
&lt;/ol>
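&lt;p>The five steps above map directly onto a scalar loop. Below is a minimal host-side C++ reference for a single row. It is a sketch under simplifying assumptions, not the kernel&amp;rsquo;s code: the scores are assumed to be already scaled, it uses plain &lt;code>std::exp&lt;/code> rather than any base-2 trick, and &lt;code>OnlineRow&lt;/code>, &lt;code>online_step&lt;/code>, and &lt;code>finalize&lt;/code> are hypothetical names.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">#include &amp;lt;algorithm&amp;gt;
#include &amp;lt;cassert&amp;gt;
#include &amp;lt;cmath&amp;gt;
#include &amp;lt;cstddef&amp;gt;
#include &amp;lt;limits&amp;gt;
#include &amp;lt;vector&amp;gt;

// Hypothetical host-side reference (not the kernel): softmax state for ONE row.
struct OnlineRow {
  float m = -std::numeric_limits&amp;lt;float&amp;gt;::infinity();  // running max (m)
  float l = 0.0f;                                     // running expsum (l)
  std::vector&amp;lt;float&amp;gt; acc;                             // unnormalized output row
};

// One loop iteration for this row: scores is the row slice of S for one
// K/V block, and v[n] is the V row that pairs with scores[n].
void online_step(OnlineRow&amp;amp; st, const std::vector&amp;lt;float&amp;gt;&amp;amp; scores,
                 const std::vector&amp;lt;std::vector&amp;lt;float&amp;gt;&amp;gt;&amp;amp; v) {
  const std::size_t d = v[0].size();
  if (st.acc.empty()) st.acc.assign(d, 0.0f);
  // 1. new per-row max
  float m_new = st.m;
  for (float s : scores) m_new = std::max(m_new, s);
  // 2. exponentiate the scores against the new max
  std::vector&amp;lt;float&amp;gt; p(scores.size());
  for (std::size_t n = 0; n &amp;lt; scores.size(); ++n) p[n] = std::exp(scores[n] - m_new);
  // 3. correction factor; exp(-inf) == 0 on the first block
  const float correction = std::exp(st.m - m_new);
  // 4. rescale the output accumulator
  for (float&amp;amp; a : st.acc) a *= correction;
  // 5. rescale the old expsum and add this block sum
  float block_sum = 0.0f;
  for (float x : p) block_sum += x;
  st.l = st.l * correction + block_sum;
  st.m = m_new;
  // then the SV GEMM for this block: acc += p @ v
  for (std::size_t n = 0; n &amp;lt; p.size(); ++n)
    for (std::size_t k = 0; k &amp;lt; d; ++k) st.acc[k] += p[n] * v[n][k];
}

// After the last block, divide by the final expsum.
std::vector&amp;lt;float&amp;gt; finalize(const OnlineRow&amp;amp; st) {
  assert(st.l &amp;gt; 0.0f);
  std::vector&amp;lt;float&amp;gt; out(st.acc);
  for (float&amp;amp; x : out) x /= st.l;
  return out;
}
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>Feeding a row&amp;rsquo;s scores through &lt;code>online_step&lt;/code> in two blocks and then calling &lt;code>finalize&lt;/code> should match a one-shot softmax over all the scores, up to float rounding; the correction factor is exactly what makes the earlier blocks&amp;rsquo; contributions consistent with the newer max.&lt;/p>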
&lt;p>To track the softmax state, the source code uses a small &lt;code>Softmax&lt;/code> struct that holds each thread&amp;rsquo;s running max and expsum registers:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">kNRows&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="k">struct&lt;/span> &lt;span class="nc">Softmax&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">using&lt;/span> &lt;span class="n">TensorT&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">float&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kNRows&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">TensorT&lt;/span> &lt;span class="n">row_max&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="c1">// running per-row max (m)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">TensorT&lt;/span> &lt;span class="n">row_sum&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="c1">// running per-row sum (l)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
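&lt;/div>&lt;p>Here &lt;code>kNRows&lt;/code> is the number of rows each thread touches; FA2 instantiates it as &lt;code>2 * size&amp;lt;1&amp;gt;(acc_o)&lt;/code>, i.e. &lt;code>2 * MMA_M&lt;/code>. You can opt for other equally valid approaches, but a softmax struct keeps the code clean and separate.&lt;/p>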
&lt;h2 id="loop-scale-softmax">Loop: Scale Softmax&lt;/h2>
&lt;p>At each loop iteration, after we&amp;rsquo;ve computed &lt;code>S = Q@K.T&lt;/code>, we have to compute the online softmax on &lt;code>S&lt;/code> and update our running max and sum. We define a method, &lt;code>softmax_rescale_o&lt;/code>, that is invoked right after &lt;code>S&lt;/code> is computed:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">bool&lt;/span> &lt;span class="n">Is_first&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Tensor0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Tensor1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">void&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">softmax_rescale_o&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor0&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (MMA, MMA_M, MMA_N) score block, fp32
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Tensor1&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (MMA, MMA_M, MMA_K) output acc, fp32
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>So what actually goes into computing the max and the sum?&lt;/p>
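&lt;p>Before answering that, one note on the &lt;code>softmax_scale_log2&lt;/code> argument in the signature above. Its name suggests the trick FA2 uses: work in base 2, since $e^x = 2^{x \log_2 e}$ and &lt;code>exp2f&lt;/code> maps to a fast hardware instruction, with the softmax scale and $\log_2 e$ folded into one constant ahead of time. Here is a host-side sketch of that identity, assuming that is what the parameter means in this kernel; &lt;code>fold_log2e&lt;/code> and &lt;code>scaled_exp2&lt;/code> are hypothetical helpers, not part of the source.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">#include &amp;lt;cmath&amp;gt;

// exp(x) == exp2(x * log2(e)), so the softmax scale and log2(e) can be
// folded into a single constant ahead of time.
inline float fold_log2e(float softmax_scale) {
  return softmax_scale * 1.4426950408889634f;  // log2(e)
}

// exp((s - m) * softmax_scale), evaluated in base 2. m_scaled = m * scale_log2
// is computed once per row; each element then costs one fma plus one exp2.
inline float scaled_exp2(float s, float scale_log2, float m_scaled) {
  return std::exp2(std::fma(s, scale_log2, -m_scaled));
}
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>The design choice is about instruction count: subtracting a pre-scaled row max lets every score go through a single fused multiply-add before the exponential, instead of a scale, a subtract, and a base-e &lt;code>exp&lt;/code>.&lt;/p>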
&lt;h2 id="row-reduce">Row Reduce&lt;/h2>
&lt;p>This is the moment where the &lt;a class="link" href="#tiled-mma" >warp-per-row tiling&lt;/a> pays off. Each warp fully owns the 16 rows of its MMA tile &amp;ndash; no row is split across warps. Therefore, the per-row max and sum reductions stay inside a warp and resolve via &lt;strong>warp reduction&lt;/strong>, &amp;ldquo;a highly efficient CUDA parallel reduction technique that aggregates data across 32 threads within a single GPU warp.&amp;rdquo;&lt;sup id="fnref:8">&lt;a href="#fn:8" class="footnote-ref" role="doc-noteref">8&lt;/a>&lt;/sup> CUDA provides warp shuffle primitives such as &lt;code>__shfl_down_sync()&lt;/code> and &lt;code>__shfl_xor_sync()&lt;/code>, which exchange register values directly between threads in a warp, with no shared-memory staging. No shared-memory round trip, no &lt;code>__syncthreads()&lt;/code>, and our max/sum is nearly free.&lt;/p>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/xor.png"
width="2112"
height="1632"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/xor_hu631af095a94c75ca8e2dcd7248099ffd_220459_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/xor_hu631af095a94c75ca8e2dcd7248099ffd_220459_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/xor_hu631af095a94c75ca8e2dcd7248099ffd_220459_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/xor.png 2112w"
loading="lazy"
alt="XOR shuffle pattern. Thanks to Hyunsung Lee&amp;rsquo;s Blog for this graphic"
class="gallery-image"
data-flex-grow="129"
data-flex-basis="310px"
>
&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: Warp shuffles are usually the fastest way for threads in the same warp to exchange data.&lt;/p>
&lt;/blockquote>
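&lt;p>Which lanes actually share a row? In the 16x8 fp32 C fragment of the PTX ISA&amp;rsquo;s &lt;code>mma.m16n8k16&lt;/code>, lane &lt;code>4*g + t&lt;/code> holds rows &lt;code>g&lt;/code> and &lt;code>g + 8&lt;/code>, so each row is split across only four consecutive lanes (a quad). The cross-thread step therefore needs just two XOR rounds (offsets 2 and 1) instead of the full 16, 8, 4, 2, 1 butterfly; FA2 calls this a quad allreduce. Below is a host-side simulation of the pattern, not device code: reading a snapshot at &lt;code>lane ^ offset&lt;/code> stands in for &lt;code>__shfl_xor_sync&lt;/code>, and &lt;code>c_frag_row&lt;/code>, &lt;code>c_frag_col&lt;/code>, and &lt;code>xor_allreduce&lt;/code> are hypothetical names.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">#include &amp;lt;array&amp;gt;

// Ownership in the 16x8 fp32 C fragment of mma.m16n8k16 (per the PTX ISA):
// lane = 4*g + t, with g = lane / 4 and t = lane % 4, holds rows g and g + 8
// and columns 2t and 2t + 1. value = 0..3 follows the (2, 2) = (col, row) order.
inline int c_frag_row(int lane, int value) { return lane / 4 + 8 * (value / 2); }
inline int c_frag_col(int lane, int value) { return 2 * (lane % 4) + value % 2; }

// Host-side simulation of an XOR-butterfly allreduce over groups of width lanes.
// Each round, every lane reads its partner from the same snapshot, which is
// what __shfl_xor_sync(0xffffffff, x, offset) does in one instruction.
template &amp;lt;typename Op&amp;gt;
std::array&amp;lt;float, 32&amp;gt; xor_allreduce(std::array&amp;lt;float, 32&amp;gt; lanes, int width, Op op) {
  for (int offset = width / 2; offset &amp;gt; 0; offset /= 2) {
    const std::array&amp;lt;float, 32&amp;gt; snapshot = lanes;
    for (int lane = 0; lane &amp;lt; 32; ++lane)
      lanes[lane] = op(snapshot[lane], snapshot[lane ^ offset]);
  }
  return lanes;
}
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>With &lt;code>width = 4&lt;/code>, every lane in a quad ends up holding that quad&amp;rsquo;s max (or sum), so all four owners of a row agree on the result without touching shared memory; &lt;code>width = 32&lt;/code> gives the whole-warp version.&lt;/p>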
&lt;p>However, warp reduction is actually step two&amp;ndash;it finds the max/sum &lt;em>between threads&lt;/em>. We first have to find the max/sum &lt;em>per-thread&lt;/em>.&lt;/p>
&lt;h3 id="thread-reduce">Thread Reduce&lt;/h3>
&lt;p>Recall that every thread in a 16x8 MMA output fragment holds $16\cdot 8 / 32=4$ output values. To find a thread&amp;rsquo;s partial row max or row sum, we reduce over all the elements the thread owns in each row, for every row the thread touches. If you&amp;rsquo;re confused, don&amp;rsquo;t worry. It will make sense as we re-examine the &lt;a class="link" href="#tiled-mma" >MMA Atom&amp;rsquo;s thread layout&lt;/a> from earlier.&lt;/p>
&lt;p>Let&amp;rsquo;s look at the bottom right output fragment C this time and zoom in on thread 0&amp;rsquo;s values. We can see it owns output elements &lt;code>(0, 0), (0, 1), (8, 0), (8, 1)&lt;/code>. If we examine all the other threads, we see that they each own 4 elements across two rows. Let&amp;rsquo;s clarify the math a bit:&lt;/p>
&lt;p>From the &lt;a class="link" href="#mma-shape" >MMA Shape&lt;/a> section, the C-fragment has shape &lt;code>((2, 2), MMA_M, MMA_N)&lt;/code> where &lt;code>MMA_M = kBlockM / (16*kNWarps)&lt;/code> and &lt;code>MMA_N = kBlockN / 8&lt;/code>. The &lt;code>(2, 2)&lt;/code> is each thread&amp;rsquo;s 4 values per tile (2 rows, 2 columns), so each thread&amp;rsquo;s values span &lt;code>MMA_M * 2&lt;/code> rows and &lt;code>MMA_N * 2&lt;/code> columns in total.&lt;/p>
&lt;p>To make the reduction loop look like ordinary 2D code, Dao reshapes this hierarchical fragment into a 2D &lt;code>(2*MMA_M, 2*MMA_N)&lt;/code> view indexed as (row, column). Since the block sizes are static, the reshape is resolved at compile time and costs nothing at runtime. It&amp;rsquo;s purely a code-quality trick; the generated code should be the same as iterating over the raw MMA shape.&lt;/p>
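&lt;p>Once the fragment is viewed as (rows, columns), the per-thread half of the reduce is an ordinary nested loop. Here is a host-side sketch, assuming one thread&amp;rsquo;s fragment is modeled as a flat array with the compact &lt;code>((2, 2), MMA_M, MMA_N)&lt;/code> strides; &lt;code>Fragment&lt;/code>, &lt;code>offset&lt;/code>, and &lt;code>thread_row_max&lt;/code> are hypothetical names, and the index map is written by hand rather than with CuTe&amp;rsquo;s layout algebra.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">#include &amp;lt;algorithm&amp;gt;
#include &amp;lt;cmath&amp;gt;
#include &amp;lt;vector&amp;gt;

// Hypothetical host-side model of one thread&amp;#39;s acc_s fragment: shape
// ((2, 2), MMA_M, MMA_N) with compact strides ((1, 2), 4, 4 * MMA_M), where
// the inner (2, 2) is (column-in-atom, row-in-atom).
struct Fragment {
  int mma_m, mma_n;
  std::vector&amp;lt;float&amp;gt; data;  // 4 * mma_m * mma_n values

  // (i, j) in the (2 * MMA_M, 2 * MMA_N) row/column view -&amp;gt; fragment offset.
  // Row i splits into (row-in-atom, m_tile), column j into (col-in-atom, n_tile).
  int offset(int i, int j) const {
    return (j % 2) * 1 + (i % 2) * 2 + (i / 2) * 4 + (j / 2) * (4 * mma_m);
  }
  float at(int i, int j) const { return data[offset(i, j)]; }
};

// Step one of the row reduce: each thread folds the 2 * MMA_N values it owns in
// each of its rows. The cross-thread (quad) reduction still has to follow.
inline std::vector&amp;lt;float&amp;gt; thread_row_max(const Fragment&amp;amp; f) {
  std::vector&amp;lt;float&amp;gt; row_max(2 * f.mma_m, -INFINITY);
  for (int i = 0; i &amp;lt; 2 * f.mma_m; ++i)
    for (int j = 0; j &amp;lt; 2 * f.mma_n; ++j)
      row_max[i] = std::max(row_max[i], f.at(i, j));
  return row_max;
}
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>For the example used below (&lt;code>kBlockM=kBlockN=128, kNWarps=1&lt;/code>, so &lt;code>MMA_M=8, MMA_N=16&lt;/code>), &lt;code>offset(9, 6)&lt;/code> comes out to 114, and every &lt;code>(i, j)&lt;/code> in the &lt;code>(16, 32)&lt;/code> view lands on a distinct fragment offset, so the view is a pure re-indexing with no data movement.&lt;/p>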
&lt;h3 id="fragment-reshape">Fragment Reshape&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/utils.cuh" target="_blank" rel="noopener"
>&lt;code>utils.cuh&lt;/code>&lt;/a> (&lt;code>convert_layout_rowcol&lt;/code>)&lt;/p>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/fragment_reshape.cu" target="_blank" rel="noopener"
>&lt;code>scratch/fragment_reshape.cu&lt;/code>&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/layout_row_col.png"
width="2310"
height="1270"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/layout_row_col_hu75a85792dfe651d98ea8fc8269578250_161303_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/layout_row_col_hu75a85792dfe651d98ea8fc8269578250_161303_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/layout_row_col_hu75a85792dfe651d98ea8fc8269578250_161303_2048x0_resize_lanczos_3.png 2048w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/layout_row_col.png 2310w"
loading="lazy"
alt="Visualization of reshape from left, ((2,2), MMA_M, MMA_N), to right, row x column format: (2xMMA_M, 2xMMA_N)."
class="gallery-image"
data-flex-grow="181"
data-flex-basis="436px"
>
&lt;/p>
&lt;p>We can create an inline function for this reshape. You don&amp;rsquo;t need to know the CuTe methods in detail as long as you understand the mechanism behind them; CuTe just provides some out-of-the-box methods that make this layout algebra slightly easier for us. Let&amp;rsquo;s understand some shapes with an example: &lt;code>kBlockM=kBlockN=128, kNWarps=1&lt;/code>.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// acc_s is the accumulator for Q@K.T
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">());&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// ((_2,_2),_8,_16):((_1,_2),_4,_32)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Target output shape: (16, 32)
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>We want our output shape to be &lt;code>(2*MMA_M, 2*MMA_N)&lt;/code> to mirror the standard 2D (row, column) format, which means we have to distribute the &lt;code>(2, 2)&lt;/code> MMA atom value mode across our tile dimensions. The catch is that the atom&amp;rsquo;s value mode is ordered column-first: &lt;code>(2,2)=(j, i)&lt;/code>, where the first sub-mode (stride 1) steps across columns and the second (stride 2) steps down rows. This is a hardware-layout choice and can be an area of confusion. We can verify this by printing the C fragment thread layout:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// TV = thread value
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">print_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_layoutC_TV&lt;/span>&lt;span class="p">());&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>
&lt;img src="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_thread_value_layout.png"
width="1312"
height="1314"
srcset="https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_thread_value_layout_huf5ef5f6c8af0c31e3f8ca5d56f73a263_186770_480x0_resize_lanczos_3.png 480w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_thread_value_layout_huf5ef5f6c8af0c31e3f8ca5d56f73a263_186770_1024x0_resize_lanczos_3.png 1024w, https://blog.echen.io/p/flashattention-2-in-cute-from-scratch/mma_thread_value_layout.png 1312w"
loading="lazy"
alt="MMA C fragment thread value layout. The row labels are the thread numbers and the columns are the thread values. I truncated it at threads 0-7 for brevity, but the full print shows all 32 threads. There are 8 thread values instead of 4 because this is the full output C tile, which combines two 16x8 atoms into the 16x16 output; this is equivalent to the values in two adjacent N-tiles in acc_s."
class="gallery-image"
data-flex-grow="99"
data-flex-basis="239px"
>
&lt;/p>
&lt;p>We can see that thread 0&amp;rsquo;s values &lt;code>(0, 0)&lt;/code> and &lt;code>(0, 1)&lt;/code> (values 0 and 1) are at memory locations 0 and 16, while &lt;code>(8, 0)&lt;/code> and &lt;code>(8, 1)&lt;/code> are at 8 and 24. Since column 0, row 8 is at memory location 8, the tile coordinates are column-major. Therefore, to grab the 2nd row, 1st element (&lt;code>i=1, j=0&lt;/code>) at tile &lt;code>(4, 3)&lt;/code>, we&amp;rsquo;d index &lt;code>((0, 1), 4, 3)&lt;/code> in the original layout.&lt;/p>
&lt;p>For our reshaped layout, we cannot simply reinterpret the shape as &lt;code>(2*MMA_M, 2*MMA_N)&lt;/code> because neither the rows nor the columns are contiguous in memory in our tiled fragment. Instead, we rely on our handy-dandy hierarchical layouts. Each &lt;code>MMA_M&lt;/code> tile spans two rows and each &lt;code>MMA_N&lt;/code> tile spans two columns, so we can map them to 2D via a composed 2D layout: &lt;code>((MMA_M, 2), (MMA_N, 2))&lt;/code>. Since CuTe indexes colexicographically (leftmost mode fastest), &lt;code>(MMA_M, 2)&lt;/code> would iterate over the &lt;code>MMA_M&lt;/code> dimension first, so consecutive indices would step across M tiles rather than across the two rows within a tile. We want adjacent indices to be adjacent rows, so we remedy this by flipping the dims: &lt;code>((2, MMA_M), (2, MMA_N))&lt;/code>. Fixing the strides is easy; we just move each old stride to its mode&amp;rsquo;s new position in the new shape, and that&amp;rsquo;s it:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># pseudo code&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">old_layout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">):((&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">new_layout&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">)):((&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This composed layout can be addressed as shape &lt;code>(16, 32)&lt;/code> even though it has sub-layouts. We can index it via &lt;code>(i, j)&lt;/code> and CuTe figures out the math under the hood. The stride math is identical: we match each layout dimension with its original stride but redistribute the dimensions so our final output shape is &lt;code>(2*MMA_M, 2*MMA_N)&lt;/code>. Let&amp;rsquo;s verify with our example:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># IDX = ((0, 1), 4, 3) = ((col, row), m_tile, n_tile)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># each term is index * stride, with strides ((1, 2), 4, 32)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">original_address = (0*1 + 1*2) + 4*4 + 3*32  &lt;span class="c1"># 114&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># compute (i, j) for IDX in the (16, 32) row/column view&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># m_tile * 2 rows/tile + row-in-atom&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">i = 4*2 + 1  &lt;span class="c1"># 9&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># n_tile * 2 cols/tile + col-in-atom&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">j = 3*2 + 0  &lt;span class="c1"># 6&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Remember, colex indexing: the leftmost sub-mode varies fastest&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># convert i=9 to the row mode (2, 8)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">i_tuple = (i % 2, (i // 2) % 8)  &lt;span class="c1"># (1, 4)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># convert j=6 to the column mode (2, 16)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">j_tuple = (j % 2, (j // 2) % 16)  &lt;span class="c1"># (0, 3)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">final_address = (i_tuple[0]*2 + i_tuple[1]*4) + (j_tuple[0]*1 + j_tuple[1]*32)  &lt;span class="c1"># 114&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">assert final_address == original_address == 114
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>All this row-major/column-major conversion is extremely confusing and was the source of an unbelievable headache, all simply for a reshape in code. You could have kept the column-major ordering or not reshaped at all, but at least you can now understand the FA2 production source code. Dao implements this approach like so:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">__device__&lt;/span> &lt;span class="k">auto&lt;/span> &lt;span class="n">convert_layout_rowcol&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Layout&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">in&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// (MMA, MMA_M, MMA_N), MMA=4 -&amp;gt; (2,2)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">auto&lt;/span> &lt;span class="n">sl&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">logical_divide&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">in&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{});&lt;/span> &lt;span class="c1">// ((2, MMA/2), MMA_M, MMA_N)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">return&lt;/span> &lt;span class="nf">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sl&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sl&lt;/span>&lt;span class="p">)),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sl&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sl&lt;/span>&lt;span class="p">)));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>In this code, the &lt;code>logical_divide&lt;/code> is actually a no-op. CuTe already gives us &lt;code>acc_s&lt;/code> as &lt;code>((2, 2), MMA_M, MMA_N)&lt;/code>, so dividing the first mode by 2 changes nothing. My best guess is that it&amp;rsquo;s a safeguard: if the function were somehow given a flat &lt;code>MMA=4&lt;/code> mode instead of &lt;code>(2, 2)&lt;/code>, the divide would split it into the &lt;code>(2, 2)&lt;/code> shape the rest of the function expects. Either way, it doesn&amp;rsquo;t break anything, and since it&amp;rsquo;s a static divide, the compiler emits the same PTX regardless.&lt;/p>
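&lt;p>To make the two halves of &lt;code>(2, 2)&lt;/code> concrete, here is a small host-side sketch in plain C++ (not CUDA) of where one lane&amp;rsquo;s four accumulator registers land, assuming the SM80 16x8x16 atom with FP32 accumulators and following the fragment layout in the PTX ISA docs. It covers a single atom, not the tiled &lt;code>MMA_M&lt;/code>/&lt;code>MMA_N&lt;/code> iterations, and &lt;code>RowCol&lt;/code>, &lt;code>acc_coord&lt;/code>, and &lt;code>rowcol_to_flat&lt;/code> are hypothetical names:&lt;/p>

```cpp
// Host-side sketch, not CUDA: where one lane's accumulator registers land in
// a single SM80 16x8 MMA atom (FP32 accumulator fragment layout, PTX ISA).
// v is the flat register index 0..3, i.e. the (2, 2) MMA mode.
struct RowCol {
  int row;
  int col;
};

constexpr RowCol acc_coord(int lane, int v) {
  int group = lane / 4;            // quad index: owns rows group and group + 8
  int thread_in_group = lane % 4;  // which pair of columns inside those rows
  int col_bit = v % 2;             // first 2 of (2, 2): adjacent columns
  int row_half = v / 2;            // second 2 of (2, 2): rows 8 apart
  return RowCol{group + 8 * row_half, 2 * thread_in_group + col_bit};
}

// Row/col view: register (r, c), with r and c in {0, 1}, back to the flat
// index. CuTe's compact (2, 2):(1, 2) layout makes c the fastest-moving index.
constexpr int rowcol_to_flat(int r, int c) { return c + 2 * r; }
```

&lt;p>The first &lt;code>2&lt;/code> steps between two adjacent columns and the second steps down 8 rows, which is why &lt;code>convert_layout_rowcol&lt;/code> groups &lt;code>get&amp;lt;0, 1&amp;gt;&lt;/code> with &lt;code>MMA_M&lt;/code> into the row mode and &lt;code>get&amp;lt;0, 0&amp;gt;&lt;/code> with &lt;code>MMA_N&lt;/code> into the column mode.&lt;/p>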
&lt;h3 id="thread-reduce-continued">Thread Reduce, Continued&lt;/h3>
&lt;p>With the fragment in row/col view, the per-thread reduction becomes trivial. Each thread simply iterates through the values it owns in each row and computes that row&amp;rsquo;s max and sum. Everything happens in the thread&amp;rsquo;s own registers, so it is very fast.&lt;/p>
&lt;p>We have two reduction operations, &lt;code>max&lt;/code> and &lt;code>sum&lt;/code>. It&amp;rsquo;s common to define them as functor structs, so the same reduction code can take either one as a template argument and the call still inlines:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">struct&lt;/span> &lt;span class="nc">MaxOp&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="nf">operator&lt;/span>&lt;span class="p">()(&lt;/span>&lt;span class="kt">float&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">a&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">b&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="nl">a&lt;/span> &lt;span class="p">:&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">struct&lt;/span> &lt;span class="nc">SumOp&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="nf">operator&lt;/span>&lt;span class="p">()(&lt;/span>&lt;span class="kt">float&lt;/span> &lt;span class="n">a&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">a&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">b&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Next, we define the actual thread reduce function. It&amp;rsquo;s just a nested for-loop that keeps track of the per-row max or sum in a register tensor. Like pretty much every other loop in our kernel, this one has compile-time-constant trip counts, so we can fully unroll it. Unrolling also keeps every tensor index static, which lets the fragment stay in registers instead of spilling to local memory.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">bool&lt;/span> &lt;span class="n">zero_init&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">true&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Op&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">void&lt;/span> &lt;span class="n">thread_reduce_&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (M, N)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">dst&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (kNRows,) per-row scratch
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Op&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">row&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">row&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">row&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dst&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">zero_init&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">?&lt;/span> &lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">:&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dst&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">col&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">col&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">col&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dst&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">col&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">dst&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
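&lt;/div>&lt;p>We use &lt;code>size&amp;lt;0&amp;gt;&lt;/code> and &lt;code>size&amp;lt;1&amp;gt;&lt;/code> to get the row and column counts, and we add a &lt;code>zero_init&lt;/code> template parameter for the first softmax call (the first KV block): when it is &lt;code>true&lt;/code>, each &lt;code>dst(row)&lt;/code> starts from the row&amp;rsquo;s first value instead of being combined with whatever it held before. That&amp;rsquo;s it!&lt;/p>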
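&lt;p>If the CuTe indexing gets in the way, here is the same logic as a host-side sketch on a plain 2D array (plain C++, not CUDA). The sizes &lt;code>kRows&lt;/code>/&lt;code>kCols&lt;/code> and the helpers &lt;code>max_op&lt;/code>, &lt;code>sum_op&lt;/code>, and &lt;code>running_max_row0&lt;/code> are hypothetical stand-ins; function pointers replace the functor template, and &lt;code>zero_init&lt;/code> becomes a runtime flag:&lt;/p>

```cpp
// Host-side sketch of thread_reduce_ on a plain 2D array (not CUDA).
// kRows/kCols, max_op and sum_op are hypothetical stand-ins for the register
// tensor and the MaxOp/SumOp functors; zero_init is a runtime flag here.
constexpr int kRows = 2;  // rows this thread owns in the row/col view
constexpr int kCols = 4;  // values per row this thread owns

constexpr float max_op(float a, float b) { return a > b ? a : b; }
constexpr float sum_op(float a, float b) { return a + b; }
using BinaryOp = float (*)(float, float);

// zero_init = true: start each row from this tile (first KV block).
// zero_init = false: fold this tile into the running value already in dst.
constexpr void thread_reduce(const float tile[][kCols], float dst[],
                             bool zero_init, BinaryOp op) {
  for (int row = 0; row != kRows; ++row) {
    dst[row] = zero_init ? tile[row][0] : op(dst[row], tile[row][0]);
    for (int col = 1; col != kCols; ++col) {
      dst[row] = op(tile[row][col], dst[row]);
    }
  }
}

// Two KV blocks in sequence: the second call must keep the first block's max.
constexpr float running_max_row0() {
  const float block0[kRows][kCols] = {{1, 7, 2, 3}, {0, 0, 0, 0}};
  const float block1[kRows][kCols] = {{4, 5, 6, 2}, {9, 1, 1, 1}};
  float row_max[kRows] = {};
  thread_reduce(block0, row_max, true, max_op);
  thread_reduce(block1, row_max, false, max_op);
  return row_max[0];  // max(7, 6) = 7
}
```

&lt;p>The second call shows why &lt;code>zero_init&lt;/code> matters: with it set to false, row 0 keeps the 7 from the first block even though the second block&amp;rsquo;s largest value is 6.&lt;/p>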
&lt;h3 id="warp-reduce">Warp Reduce&lt;/h3>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/softmax.cuh" target="_blank" rel="noopener"
>&lt;code>softmax.cuh&lt;/code>&lt;/a> (&lt;code>Allreduce&amp;lt;N&amp;gt;&lt;/code>)&lt;/p>
&lt;p>&lt;strong>Play:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/scratch/test_allreduce.cu" target="_blank" rel="noopener"
>&lt;code>scratch/test_allreduce.cu&lt;/code>&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;p>Now each thread has its own partial max and sum for each of its rows, so we warp-reduce them across threads. GPU reductions follow a tree-reduce paradigm: instead of looping over all threads in $O(n)$ time, pairs of threads reduce against each other at each step, single-elimination bracket-style. Each round halves the number of values still in play, so we only need $O(\log(n))$ rounds to find the final max.&lt;/p>
&lt;p>The simplest strategy pairs each thread with the thread $N/2$ above it. Thread 0 pairs with 16, 1 with 17, up to 15 with 31, and now threads 0-15 hold the pairwise maxes. Then thread 0 pairs with 8, 1 with 9, up to 7 with 15. Each round we halve the stride (16-&amp;gt;8-&amp;gt;4-&amp;gt;2-&amp;gt;1) until thread 0 holds the final max. CUDA&amp;rsquo;s shuffle for this pattern is &lt;code>__shfl_down_sync()&lt;/code>, which would be good enough except that only thread 0 ends up with the final value. In our case, every thread needs the max/sum to calculate its share of the softmax. Instead, we use the &lt;code>__shfl_xor_sync()&lt;/code> primitive. You might tense up at the idea of XOR again, but this one is gentler than swizzling: lane &lt;code>i&lt;/code> trades values with lane &lt;code>i ^ stride&lt;/code>, and since that pairing is symmetric, both partners compute the same result. After all $\log_2(N)$ rounds, every thread holds the final value.&lt;sup id="fnref:9">&lt;a href="#fn:9" class="footnote-ref" role="doc-noteref">9&lt;/a>&lt;/sup>&lt;/p>
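&lt;p>Here is a host-side simulation of that XOR pairing (plain C++, not CUDA). The array &lt;code>lanes&lt;/code> stands in for one register per lane, and each round reads the previous round&amp;rsquo;s values, which is what &lt;code>__shfl_xor_sync&lt;/code> hands back on the device. &lt;code>butterfly_max&lt;/code> and &lt;code>every_lane_gets_warp_max&lt;/code> are hypothetical names for illustration:&lt;/p>

```cpp
// Host-side simulation of the XOR butterfly behind allreduce (not CUDA).
// lanes[i] stands in for lane i's register. Each round, every lane combines
// its own value with its partner's value from the previous round, which is
// what __shfl_xor_sync(0xffffffff, x, stride) returns on the device.
constexpr int kWarpSize = 32;

constexpr float max_f(float a, float b) { return a > b ? a : b; }

// n plays the role of allreduce's N: the strides are n/2, n/4, ..., 1.
constexpr void butterfly_max(float lanes[], int n) {
  for (int stride = n / 2; stride > 0; stride /= 2) {
    float prev[kWarpSize] = {};
    for (int lane = 0; lane != kWarpSize; ++lane) prev[lane] = lanes[lane];
    for (int lane = 0; lane != kWarpSize; ++lane) {
      int partner = lane ^ stride;  // symmetric: with stride 16, 3 and 19 pair up
      lanes[lane] = max_f(prev[lane], prev[partner]);
    }
  }
}

// Lane i starts with the value i, so the warp max is 31. After five rounds
// (16, 8, 4, 2, 1), every lane should hold 31, not just lane 0.
constexpr bool every_lane_gets_warp_max() {
  float lanes[kWarpSize] = {};
  for (int lane = 0; lane != kWarpSize; ++lane) lanes[lane] = float(lane);
  butterfly_max(lanes, kWarpSize);
  for (int lane = 0; lane != kWarpSize; ++lane) {
    if (lanes[lane] != 31.0f) return false;
  }
  return true;
}
```

&lt;p>With a one-way shuffle-down pattern instead, only lane 0 would be guaranteed to end with 31.&lt;/p>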
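&lt;p>The shuffle primitives take a bitmask of all participating threads, the value to exchange, and a lane argument. For &lt;code>__shfl_xor_sync()&lt;/code>, that argument (&lt;code>laneMask&lt;/code> in the CUDA docs) is XORed with the calling thread&amp;rsquo;s lane ID to pick its partner, which for a power-of-two value sits exactly that many lanes away. We&amp;rsquo;ll call it the stride:&lt;/p>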
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">__shfl_xor_sync&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">uint&lt;/span> &lt;span class="n">b_32bit_mask&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">T&lt;/span> &lt;span class="n">value&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">uint&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Most of the time, all 32 threads participate, so the mask is all 1s (&lt;code>0xffffffff&lt;/code>). The stride is how many threads apart we look. With 32 threads, this starts at 16 and goes down to 1. We loop over all possible strides until the value is fully reduced:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">T&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Op&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">T&lt;/span> &lt;span class="n">allreduce&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">T&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Op&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">N&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">stride&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">stride&lt;/span> &lt;span class="o">/=&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">__shfl_xor_sync&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mh">0xffffffff&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="quad-reduce">Quad Reduce&lt;/h3>
&lt;p>The reason we give &lt;code>allreduce&lt;/code> a template parameter &lt;code>N&lt;/code> is that we&amp;rsquo;re not actually reducing across all 32 threads. We only need to reduce over the threads that own each row. In the &lt;a class="link" href="#thread-reduce" >MMA Atom Layout image&lt;/a> from earlier, we see that four adjacent threads collectively own each row. Therefore, &lt;code>N=4&lt;/code>, hence &amp;ldquo;quad&amp;rdquo; reduce. We also don&amp;rsquo;t have to loop over the eight quads in a warp. With &lt;code>N=4&lt;/code>, the only strides are 2 and 1, and XORing a lane ID with 2 or 1 flips only its low two bits, so every thread&amp;rsquo;s partner lives in its own aligned group of four. All 32 threads execute the same shuffle together (the mask is still &lt;code>0xffffffff&lt;/code>), and the eight quads reduce independently in parallel. Our quad reduce function is therefore quite simple:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Op&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">void&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">quad_allreduce_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">dst&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (kNRows,) per-row reduced
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (kNRows,) per-row local
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Op&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">row&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">row&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">row&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dst&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">allreduce&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">src&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">op&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Now we can create some functions that wrap the thread and quad reduces to get our max and sum.&lt;/p>
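&lt;p>Before wrapping them, a quick host-side check (plain C++, not CUDA) of the claim that &lt;code>N=4&lt;/code> keeps every exchange inside a quad. &lt;code>partners_stay_in_quad&lt;/code> and &lt;code>first_partner_n8&lt;/code> are hypothetical helpers:&lt;/p>

```cpp
// Host-side check (not CUDA): with allreduce's N = 4 the strides are only
// 2 and 1. XOR with 2 or 1 flips just the low two bits of the lane ID, so
// every partner stays inside the lane's aligned group of four (lane / 4).
constexpr int kWarpSize = 32;

constexpr bool partners_stay_in_quad() {
  for (int lane = 0; lane != kWarpSize; ++lane) {
    for (int stride = 2; stride > 0; stride /= 2) {
      int partner = lane ^ stride;
      if (partner / 4 != lane / 4) return false;  // would cross into another quad
    }
  }
  return true;
}

// For contrast: N = 8 would start at stride 4, which leaves the quad.
// Lane 1's first partner would then be lane 5, in quad 1 instead of quad 0.
constexpr int first_partner_n8(int lane) { return lane ^ 4; }
```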
&lt;h3 id="reduce-sum-and-max">Reduce Sum and Max&lt;/h3>
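&lt;p>For &lt;code>reduce_max()&lt;/code>, all we need to do is call thread reduce followed by quad reduce. No extra barrier is needed: the data never leaves registers, and the &lt;code>_sync&lt;/code> shuffle inside &lt;code>quad_allreduce_()&lt;/code> already synchronizes the lanes that exchange values:&lt;/p>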
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">bool&lt;/span> &lt;span class="n">zero_init&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">true&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">__device__&lt;/span> &lt;span class="kt">void&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">reduce_max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">thread_reduce_&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">zero_init&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">MaxOp&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">quad_allreduce_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">max&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">MaxOp&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
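&lt;/div>&lt;p>For &lt;code>reduce_sum()&lt;/code>, we don&amp;rsquo;t actually have to quad reduce the running sum until the very end. This pairs with a slight optimization FA2 brought: the final expsum division only needs to happen after the final softmax rescale. The deferral works because the row max is quad-reduced first, so all four threads that share a row scale their partial sums by the same correction factor, and adding the four partials once at the end gives the same total as reducing them every iteration. This saves the quad shuffles for every row on every iteration but the last.&lt;/p>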
&lt;p>All we need to do is update the thread sums along the way and scale by the correction factor.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">bool&lt;/span> &lt;span class="n">zero_init&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">true&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">__device__&lt;/span> &lt;span class="kt">void&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">reduce_sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// defer the cross-thread allreduce until after the last iteration:
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// until then each thread only does row_sum * correction + new_sum
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// over its own elements, skipping the per-iteration warp shuffles
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">thread_reduce_&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">zero_init&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sum&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">SumOp&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>At each step, every thread holds a partial row sum over only the elements it owns. After each rescale, we multiply it by the correction factor (covered next): &lt;code>rowsum *= correction&lt;/code>. Then the thread-local reduction adds the new elements: &lt;code>rowsum += sum(v for v in thread_row)&lt;/code>.&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: We cannot reduce max and sum at the same time. Although fusing them looks more efficient, every term of the sum is computed relative to the row max, so we need the full row max (reduced across threads) before computing the exp2 sum. A per-thread max is not enough.&lt;/p>
&lt;/blockquote>
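&lt;p>To see why deferring the allreduce is safe, here is a host-side C++ sketch of a single row rather than kernel code. Four simulated threads each keep a private partial sum. All of them are rescaled by the same correction after every tile, and the partials are added together only once at the end. The function names, the tile width, and the simplified column-to-thread mapping (column &lt;code>j&lt;/code> goes to thread &lt;code>j % 4&lt;/code>) are assumptions for illustration. Natural-base &lt;code>std::exp&lt;/code> stands in for &lt;code>exp2f&lt;/code> with the folded scale.&lt;/p>

```cpp
#include "cassert"
#include "cmath"

constexpr int kThreads = 4;  // threads sharing one row: a quad in the MMA C layout

// Online row sum with a deferred cross-thread reduction (sketch, one row).
// Column j of a tile is owned by thread j % kThreads, a simplified mapping.
float deferred_row_sum(const float* row, int n, int tile) {
    assert(n % tile == 0);
    assert(tile % kThreads == 0);
    float partial[kThreads] = {};  // each thread's private row_sum
    float row_max = -INFINITY;
    for (int start = 0; start != n; start += tile) {
        // full row max over everything seen so far (reduce_max + allreduce)
        float new_max = row_max;
        for (int j = start; j != start + tile; ++j) new_max = std::fmax(new_max, row[j]);
        // same correction for every thread; exp(-inf) == 0 on the first tile
        const float correction = std::exp(row_max - new_max);
        for (int t = 0; t != kThreads; ++t) partial[t] *= correction;
        // thread-local accumulation: no communication between threads
        for (int j = start; j != start + tile; ++j) partial[j % kThreads] += std::exp(row[j] - new_max);
        row_max = new_max;
    }
    // the deferred allreduce: one sum across the quad after the last tile
    float total = 0.f;
    for (int t = 0; t != kThreads; ++t) total += partial[t];
    return total;
}

// Reference: sum of exp(x - max) over the whole row in one pass.
float direct_row_sum(const float* row, int n) {
    float m = -INFINITY;
    for (int j = 0; j != n; ++j) m = std::fmax(m, row[j]);
    float s = 0.f;
    for (int j = 0; j != n; ++j) s += std::exp(row[j] - m);
    return s;
}
```

&lt;p>Because the correction factor comes from the full row max, it is identical for all four threads, and $\sum_t c\,p_t = c \sum_t p_t$. The cross-thread sum therefore commutes with every rescale, which is what lets the kernel postpone it to the epilogue.&lt;/p>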
&lt;h2 id="exp2-and-calculating-exprow">Exp2, and Calculating Exp(Row)&lt;/h2>
&lt;p>We now write a function to compute &lt;code>exp(Q@K.T)&lt;/code> (after max subtraction). CUDA provides several exponential functions, including &lt;code>exp&lt;/code>, &lt;code>expf&lt;/code>, and &lt;code>__expf&lt;/code>, but we won&amp;rsquo;t use any of them. Since LLMs can tolerate a decent error margin, we can use the faster &lt;code>exp2f&lt;/code> instead. NVIDIA GPUs have native hardware support for base-2 exponentiation, which can make &lt;code>exp2f&lt;/code> roughly 10-15% faster than &lt;code>expf&lt;/code>, depending on the GPU and compiler flags.&lt;/p>
&lt;blockquote>
&lt;p>&lt;code>__expf()&lt;/code> is a lower-precision, significantly faster intrinsic version of &lt;code>expf()&lt;/code>, and it may well use the same hardware base-2 exponential under the hood. But since the FA2 source code uses &lt;code>exp2f&lt;/code> directly, we&amp;rsquo;ll do the same.&lt;/p>
&lt;/blockquote>
&lt;p>To use &lt;code>exp2f(x)&lt;/code>, we have to scale &lt;code>x&lt;/code> by &lt;code>log2(e)&lt;/code>, since $2^{\log_2(e)x} = \exp(x)$.&lt;/p>
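&lt;p>Here is a quick host-side check of this identity, written as a sketch. Standard C++ &lt;code>std::exp2&lt;/code> stands in for the device &lt;code>exp2f&lt;/code>. The names &lt;code>kLog2e&lt;/code>, &lt;code>exp_via_exp2&lt;/code>, and &lt;code>max_rel_error&lt;/code> are hypothetical.&lt;/p>

```cpp
#include "cassert"
#include "cmath"

// log2(e) rounded to float: precomputed once instead of evaluated per element
constexpr float kLog2e = 1.4426950408889634f;

// e^x through the base-2 exponential: 2^(x * log2(e)) == e^x.
// std::exp2 is the host-side standard counterpart of CUDA's device exp2f.
float exp_via_exp2(float x) {
    return std::exp2(x * kLog2e);
}

// Largest relative error of exp_via_exp2 against std::exp on [lo, hi].
float max_rel_error(float lo, float hi, int steps) {
    assert(steps > 0);
    float worst = 0.f;
    for (int i = 0; i != steps + 1; ++i) {
        const float x = lo + (hi - lo) * i / steps;
        const float ref = std::exp(x);
        worst = std::fmax(worst, std::fabs(exp_via_exp2(x) - ref) / ref);
    }
    return worst;
}
```

&lt;p>Sampling only $x \le 0$ matches the kernel, because after subtracting the row max every exponent is non-positive. The error should stay on the order of single-precision rounding.&lt;/p>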
&lt;h3 id="softmax_scale_log2-dont-forget-the-scaling-factor">&lt;code>softmax_scale_log2&lt;/code>: Don&amp;rsquo;t Forget the Scaling Factor!&lt;/h3>
&lt;p>Instead of computing $\log_2(e)$ in the kernel again and again, which would erode our performance gain, we store it as a float constant. So far we&amp;rsquo;ve also neglected the denominator scale factor $1/\sqrt{d_h}$. It applies every time we compute $QK^T$ and is constant for the entire kernel, so we simply fold it into the same constant: &lt;code>softmax_scale_log2&lt;/code> $= \log_2(e)/\sqrt{d_h}$. During kernel dispatch, we pass this precomputed value in as an argument.&lt;/p>
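&lt;p>Written out for a single raw score $s$ and running row max $m$ (both unscaled, straight out of the $QK^T$ GEMM), folding $1/\sqrt{d_h}$ and $\log_2(e)$ into one constant $\gamma$ gives exactly the expression the exp loop evaluates per element:&lt;/p>
&lt;p>$$
\exp\left(\frac{s - m}{\sqrt{d_h}}\right) = 2^{\log_2(e)\,\frac{s - m}{\sqrt{d_h}}} = 2^{\,s\gamma - m\gamma},
\qquad \gamma = \frac{\log_2(e)}{\sqrt{d_h}} = \texttt{softmax\_scale\_log2}
$$&lt;/p>
&lt;p>Because $\gamma > 0$, scaling does not change which score is largest, so the max can be reduced over the raw scores and scaled afterwards.&lt;/p>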
&lt;h3 id="exp-loop">Exp Loop&lt;/h3>
&lt;p>For the exp loop, we iterate over the rows again, scaling the max and each value by the log2 scale factor before &lt;code>exp2f&lt;/code>:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">__device__&lt;/span> &lt;span class="kt">void&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">scale_apply_exp2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">r&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">r&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">r&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">float&lt;/span> &lt;span class="n">adj_max&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">c&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// the compiler can fuse x * scale - adj_max into one FMA instruction
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="c1">// (fused multiply-add), which is why max is pre-scaled above
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">exp2f&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">adj_max&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="full-softmax-call-rescale">Full Softmax Call: Rescale&lt;/h2>
&lt;p>It&amp;rsquo;s time to piece all our previous functions together into a single function that each main loop iteration invokes after computing $QK^T$. It performs the following steps:&lt;/p>
&lt;ul>
&lt;li>Reshape our accumulators to the expected layout&lt;/li>
&lt;li>Compute the &lt;code>max&lt;/code>&lt;/li>
&lt;li>Apply the correction to the previous output and expsum&lt;/li>
&lt;li>Call &lt;code>scale_apply_exp2&lt;/code>&lt;/li>
&lt;li>Reduce the sum&lt;/li>
&lt;/ul>
&lt;p>In the &lt;code>Softmax{}&lt;/code> struct we declared at the beginning of &lt;a class="link" href="#online-softmax" >this section&lt;/a>, we add the following function:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">bool&lt;/span> &lt;span class="n">Is_first&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Tensor0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Tensor1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">void&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">softmax_rescale_o&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor0&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (MMA, MMA_M, MMA_N) score block, fp32
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Tensor1&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (MMA, MMA_M, MMA_HEAD_DIM) output acc, fp32
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>We first reshape &lt;code>acc_s&lt;/code> to our row/column view. If this is the first rescale call, we don&amp;rsquo;t have to do any rescaling, so we template this call on the boolean &lt;code>Is_first&lt;/code>.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">softmax_rescale_o&lt;/span>&lt;span class="p">(...)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span> &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">convert_layout_rowcol&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">()));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">Is_first&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span> &lt;span class="c1">// first block, no prevs
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">reduce_max&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="nb">true&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_max&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">scale_apply_exp2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_max&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">reduce_sum&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="nb">true&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_sum&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>For the standard call, we first save the previous row max and reduce the new max. We then apply the correction to the running expsum and to the previous output (which we reshape as well). Then we call &lt;code>scale_apply_exp2&lt;/code> and reduce the sum.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
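&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">softmax_rescale_o&lt;/span>&lt;span class="p">(...)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">else&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// keep the previous max: the correction needs m_old
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">Tensor&lt;/span> &lt;span class="n">row_max_old&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_fragment_like&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row_max&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row_max&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_max_old&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">reduce_max&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="nb">false&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_max&lt;/span>&lt;span class="p">);&lt;/span>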
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">convert_layout_rowcol&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">()));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// apply correction to output
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">r&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">r&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">r&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// exp(m_old-m_new)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">correction&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">exp2f&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">row_max_old&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">row_max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">row_sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">*=&lt;/span> &lt;span class="n">correction&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">c&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">*=&lt;/span> &lt;span class="n">correction&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// exp2(scores-m_new)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">scale_apply_exp2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_max&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">softmax_scale_log2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// sum(scores_exp), per thread, full reduce at end of main kernel
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">reduce_sum&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="nb">false&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_sum&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Everything is coming together nicely now. One simple reshape, and everything is what you always imagined CUDA coding to be. If only it could always be this easy. All we need now is the final normalization at the end, where we finish the deferred expsum reduction across threads and perform one last output scale by its reciprocal.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Tensor0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__device__&lt;/span> &lt;span class="n">__forceinline__&lt;/span> &lt;span class="kt">void&lt;/span> &lt;span class="n">normalize_softmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor0&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// final expsum reduce
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">quad_allreduce_&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">row_sum&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">row_sum&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">SumOp&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Tensor&lt;/span> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">convert_layout_rowcol&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">()));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">r&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">r&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">r&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">float&lt;/span> &lt;span class="n">row_sum_i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">1.f&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">row_sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">c&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">c&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">output&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">r&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">c&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">*=&lt;/span> &lt;span class="n">row_sum_i&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This function is called in the epilogue after the main loop, before we store the output back to GMEM. Overall, the softmax step is not particularly complicated. There&amp;rsquo;s some funky, confusing layout reshaping and a few warp-reduce primitives to learn, but everything else pieces together nicely once you understand the MMA thread layout.&lt;/p>
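&lt;p>To check the whole recipe end to end, here is a scalar, single-threaded C++ sketch of one query row. It mirrors &lt;code>softmax_rescale_o&lt;/code> (first tile versus later tiles), the per-tile output update, and &lt;code>normalize_softmax&lt;/code>, and it is compared against a direct softmax-weighted sum. It is not the kernel. The sizes &lt;code>kN&lt;/code>, &lt;code>kTile&lt;/code>, and &lt;code>kD&lt;/code> and both function names are hypothetical. There is a single thread, so no quad reduction is needed, and &lt;code>std::exp2&lt;/code> stands in for &lt;code>exp2f&lt;/code>.&lt;/p>

```cpp
#include "cassert"
#include "cmath"

constexpr int kN = 8;     // keys in the row (hypothetical)
constexpr int kTile = 4;  // keys per main-loop iteration, standing in for kBlockN
constexpr int kD = 2;     // head dimension, standing in for kHeadDim
static_assert(kN % kTile == 0, "tiles must cover the row exactly");

// Online softmax for one query row: rescale on every tile after the first,
// accumulate unnormalized P V, then divide by the final expsum.
void online_softmax_row(const float* s, const float v[kN][kD], float scale_log2, float* out) {
    float row_max = -INFINITY;
    float row_sum = 0.f;
    float acc_o[kD] = {};  // like clear(acc_o)
    for (int start = 0; start != kN; start += kTile) {
        float new_max = row_max;  // reduce_max over the raw scores
        for (int j = 0; j != kTile; ++j) new_max = std::fmax(new_max, s[start + j]);
        if (start != 0) {  // Is_first == false: correct the previous sum and output
            const float correction = std::exp2((row_max - new_max) * scale_log2);
            row_sum *= correction;
            for (int d = 0; d != kD; ++d) acc_o[d] *= correction;
        }
        row_max = new_max;
        float p[kTile];
        for (int j = 0; j != kTile; ++j) {  // scale_apply_exp2 + reduce_sum
            p[j] = std::exp2(s[start + j] * scale_log2 - row_max * scale_log2);
            row_sum += p[j];
        }
        for (int j = 0; j != kTile; ++j)  // O += P V for this tile
            for (int d = 0; d != kD; ++d) acc_o[d] += p[j] * v[start + j][d];
    }
    assert(row_sum > 0.f);
    const float inv = 1.f / row_sum;  // normalize_softmax
    for (int d = 0; d != kD; ++d) out[d] = acc_o[d] * inv;
}

// Direct reference: softmax(s * softmax_scale) applied to V in one pass.
void reference_row(const float* s, const float v[kN][kD], float softmax_scale, float* out) {
    float m = -INFINITY;
    for (int j = 0; j != kN; ++j) m = std::fmax(m, s[j]);
    float denom = 0.f;
    for (int d = 0; d != kD; ++d) out[d] = 0.f;
    for (int j = 0; j != kN; ++j) {
        const float w = std::exp((s[j] - m) * softmax_scale);
        denom += w;
        for (int d = 0; d != kD; ++d) out[d] += w * v[j][d];
    }
    for (int d = 0; d != kD; ++d) out[d] /= denom;
}
```

&lt;p>The two results should agree up to float rounding. The correction is applied to both &lt;code>row_sum&lt;/code> and &lt;code>acc_o&lt;/code> before the new tile is added, the same order in which &lt;code>softmax_rescale_o&lt;/code> runs before the $SV$ GEMM.&lt;/p>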
&lt;h1 id="putting-it-all-together">Putting It All Together&lt;/h1>
&lt;p>Let&amp;rsquo;s do some bookkeeping, see where the softmax call goes, and finally implement the $O=SV$ GEMM.&lt;/p>
&lt;h2 id="creating-output-fragment-acc_o-and-softmax-struct">Creating Output Fragment &lt;code>acc_o&lt;/code> and Softmax Struct&lt;/h2>
&lt;p>Before the main loop, we first create our output fragment &lt;code>acc_o&lt;/code>, which is the actual output we rescale in Softmax and write to our output tensor. It has shape &lt;code>(kBlockM, kHeadDim)&lt;/code> and is a C-fragment just like &lt;code>acc_s&lt;/code>. Similarly, we initialize it with 0.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">acc_o&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">partition_fragment_C&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// fill with 0
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">clear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Next, we initialize our softmax struct so we can call it in the main loop. We compute the &lt;code>kNRows&lt;/code> shape based on the MMA shape of &lt;code>acc_o&lt;/code> that we covered in &lt;a class="link" href="#fragment-reshape" >fragment reshape&lt;/a>.&lt;/p>
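&lt;p>To make the &lt;code>kNRows&lt;/code> arithmetic concrete, here is a compile-time C++ sketch. It assumes an SM80-style m16n8k16 MMA atom, whose 16x8 fp32 accumulator gives each thread rows &lt;code>lane/4&lt;/code> and &lt;code>lane/4 + 8&lt;/code>. It also assumes warps tiled only along M, as in the FA2 tiled MMA. The helper names and the 128-row, 4-warp configuration are hypothetical.&lt;/p>

```cpp
#include "cassert"

// Assumptions: m16n8k16 atom, warps stacked along M only.
constexpr int kAtomRows = 16;    // rows of one MMA atom's C tile
constexpr int kRowsPerAtom = 2;  // a thread holds rows lane/4 and lane/4 + 8

// MMA_M (mode 1 of acc_o): how many times the tiled MMA repeats along M
constexpr int mma_m(int block_m, int n_warps) {
    assert(block_m % (kAtomRows * n_warps) == 0);  // block must tile evenly
    return block_m / (kAtomRows * n_warps);
}

// kNRows = 2 * MMA_M: distinct rows each thread tracks a max and sum for
constexpr int softmax_rows(int block_m, int n_warps) {
    return kRowsPerAtom * mma_m(block_m, n_warps);
}

static_assert(mma_m(128, 4) == 2, "128 rows / (16 rows * 4 warps) = 2 repeats");
static_assert(softmax_rows(128, 4) == 4, "2 rows per atom * 2 repeats");
```

&lt;p>With &lt;code>kBlockM = 128&lt;/code> and 4 warps, the tiled MMA repeats twice along M, so each thread tracks 4 rows of max and sum.&lt;/p>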
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;span class="lnt">9
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">clear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// initialize softmax, acc_o: (MMA, MMA_M, MMA_HEAD_DIM)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// rows is 2*MMA_M dim, 2 rows per thread for each MMA tile
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Softmax&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">softmax&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="cm">/* main loop */&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="the-softmax-rescale-call">The Softmax Rescale Call&lt;/h2>
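&lt;p>As described back in &lt;a class="link" href="#main-loop" >the main loop&lt;/a>, the softmax rescale happens right after $QK^T$, the V-block sync, and the K-block async issue. We branch on whether the current block is the first one. On the first block there is no previous row max or row sum to correct, and &lt;code>acc_o&lt;/code> is still all zeros. So the &lt;code>Is_first&lt;/code> instantiation skips the rescale and just initializes them.&lt;/p>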
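&lt;p>The following scalar Python model of a single row shows what the two branches compute, assuming the usual online-softmax recurrence. The class and method names are made up for illustration and are not the &lt;code>FLASH::Softmax&lt;/code> API. On the first block, the model only records the row max. On later blocks, it rescales the running sum and &lt;code>acc_o&lt;/code> by &lt;code>exp2((m_old - m_new) * scale_log2)&lt;/code> before exponentiating the new scores. Using &lt;code>exp2&lt;/code> with a scale premultiplied by &lt;code>log2(e)&lt;/code> gives the same result as &lt;code>exp&lt;/code> with the plain scale. Here that premultiplied scale is our stand-in for &lt;code>params.scale_softmax_log2&lt;/code>.&lt;/p>

```python
import math

LOG2E = math.log2(math.e)


class RowSoftmax:
    """Hypothetical scalar model of one row of the online-softmax rescale."""

    def __init__(self):
        self.row_max = None
        self.row_sum = 0.0

    def rescale_o(self, scores, acc_o, scale_log2, is_first):
        cur_max = max(scores)
        if is_first:
            # Nothing accumulated yet (acc_o is all zeros): just record the max.
            self.row_max = cur_max
        else:
            new_max = max(self.row_max, cur_max)
            # Shrink the old sum and the old output to the new max.
            factor = 2.0 ** ((self.row_max - new_max) * scale_log2)
            self.row_sum *= factor
            for d in range(len(acc_o)):
                acc_o[d] *= factor
            self.row_max = new_max
        p = [2.0 ** ((s - self.row_max) * scale_log2) for s in scores]
        self.row_sum += sum(p)
        return p


def reference(scores, v, scale):
    """Plain softmax(scale * scores) @ v for one row."""
    m = max(scores)
    w = [math.exp((s - m) * scale) for s in scores]
    z = sum(w)
    return [sum(wi * vi[d] for wi, vi in zip(w, v)) / z for d in range(len(v[0]))]


scale = 1 / math.sqrt(2)        # assumed head_dim = 2
scale_log2 = scale * LOG2E      # stand-in for params.scale_softmax_log2
scores = [0.5, 1.0, 3.0, -1.0]  # the second block raises the max: 1.0 -> 3.0
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 1.0]]

sm, acc_o = RowSoftmax(), [0.0, 0.0]
for nblock, start in enumerate((0, 2)):  # two key blocks of width 2
    s_blk, v_blk = scores[start:start + 2], v[start:start + 2]
    p = sm.rescale_o(s_blk, acc_o, scale_log2, is_first=(nblock == 0))
    for d in range(2):
        acc_o[d] += sum(pi * vi[d] for pi, vi in zip(p, v_blk))

out = [x / sm.row_sum for x in acc_o]  # final normalization after the loop
ref = reference(scores, v, scale)
assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(out, ref))
```

&lt;p>The model divides by the running sum only once, after all key blocks. This matches how FA2 normalizes &lt;code>acc_o&lt;/code> after its main loop.&lt;/p>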
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">gemm_QK&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// wait for V
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_wait&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__syncthreads&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// next K block prefetch
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">nblock&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">nBlocksN&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span> &lt;span class="c1">// not last block
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tKgK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">nblock&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tKsK&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cp_async_fence&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// 2. P=softmax(S)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">nblock&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">softmax&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="k">template&lt;/span> &lt;span class="n">softmax_rescale_o&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="cm">/*Is_first*/&lt;/span> &lt;span class="nb">true&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">acc_s&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">acc_o&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">scale_softmax_log2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span> &lt;span class="k">else&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">softmax&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="k">template&lt;/span> &lt;span class="n">softmax_rescale_o&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="cm">/*Is_first*/&lt;/span> &lt;span class="nb">false&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">acc_s&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">acc_o&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">scale_softmax_log2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="mma-loop-sv-gemm">MMA Loop: SV GEMM&lt;/h2>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/utils.cuh" target="_blank" rel="noopener"
>&lt;code>utils.cuh&lt;/code>&lt;/a> (&lt;code>gemm_rs&lt;/code>). &lt;code>rs&lt;/code>: A comes from registers and B from shared memory, so only the right-hand operand is loaded from SMEM.&lt;/p>
&lt;/blockquote>
&lt;p>Next comes $SV$. This GEMM is almost exactly the same as the one before, except that S is already in registers. This means the only SMEM-to-register copies we need in the GEMM loop this time are for V.&lt;/p>
&lt;p>However, even though &lt;code>acc_s&lt;/code> is in registers, we have to do two transformations before our GEMM:&lt;/p>
&lt;ol>
&lt;li>It is currently an fp32 accumulator. We have to convert it back to fp16 before the MMA.&lt;/li>
&lt;li>&lt;code>acc_s&lt;/code> is also currently stored as a &lt;code>fragment_C&lt;/code> since it was an accumulator for $QK^T$. We need to reshape it as a &lt;code>fragment_A&lt;/code> to pass it to the $SV$ GEMM. Therefore, we will do a frag-C to frag-A reshape, similar to our &lt;a class="link" href="#fragment-reshape" >row-col reshape util&lt;/a> for Softmax.&lt;/li>
&lt;/ol>
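&lt;p>Both transformations only re-encode or relabel values that the thread already holds in registers. As a warm-up for the conversion, here is a Python model of a compact column-major C-fragment layout. It is an illustration with assumed sizes &lt;code>MMA_M = MMA_N = 2&lt;/code>. It checks that the layout hits every storage slot exactly once. That property is what makes it safe to convert the storage as one flat array and then reuse the same layout. The &lt;code>struct&lt;/code> module&amp;rsquo;s &lt;code>'e'&lt;/code> format stands in for fp16 rounding.&lt;/p>

```python
import itertools
import struct


def to_half(x):
    """Round a Python float to the nearest IEEE fp16 value via struct's 'e' format."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Assumed C-fragment layout ((2, 2), MMA_M, MMA_N) : ((1, 2), 4, 4 * MMA_M),
# i.e. compact column-major, with MMA_M = MMA_N = 2 for the demo.
MMA_M, MMA_N = 2, 2
SHAPE = (2, 2, MMA_M, MMA_N)
STRIDE = (1, 2, 4, 4 * MMA_M)


def offset(coord):
    return sum(c * s for c, s in zip(coord, STRIDE))


coords = list(itertools.product(*(range(n) for n in SHAPE)))
# Compact: every storage slot 0..numel-1 is hit exactly once, so the
# fragment can be treated as one flat array of numel values.
assert sorted(offset(c) for c in coords) == list(range(len(coords)))

acc_s = [0.1 * k + 0.05 for k in range(len(coords))]  # made-up fp32 values
acc_s_fp16 = [to_half(x) for x in acc_s]               # flat, elementwise

# Same layout on the converted storage: coordinate c still finds its value,
# now rounded to fp16 (relative error at most 2**-11 for normal values).
for c in coords:
    x, h = acc_s[offset(c)], acc_s_fp16[offset(c)]
    assert abs(x) * 2 ** -11 >= abs(h - x)
```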
&lt;h3 id="fp32-fp16-conversion">FP32-&amp;gt;FP16 Conversion&lt;/h3>
&lt;p>CUTLASS provides the numerical conversion operator &lt;code>cutlass::NumericArrayConverter&amp;lt;To_type, From_type, numel&amp;gt;&lt;/code> for this. It operates on a flat &lt;code>cutlass::Array&amp;lt;From_type, numel&amp;gt;&lt;/code>, so it assumes the values are &amp;ldquo;contiguous&amp;rdquo;. That would be a true requirement for GMEM/SMEM &amp;ndash; but as we learned, contiguity doesn&amp;rsquo;t really exist for registers. Because the fragment&amp;rsquo;s layout is compact (its &lt;code>numel&lt;/code> values fill storage slots 0 through &lt;code>numel - 1&lt;/code>), we can reinterpret its storage as a flat array and convert it elementwise. Then we wrap the result back up with the original layout:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">To_type&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">__device__&lt;/span> &lt;span class="k">auto&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">convert_type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Engine&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// Trick to grab the cute float type from source tensor
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">using&lt;/span> &lt;span class="n">From_type&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">typename&lt;/span> &lt;span class="n">Engine&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">value_type&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// number of elements
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">numel&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="k">decltype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tensor&lt;/span>&lt;span class="p">))&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">value&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cutlass&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">NumericArrayConverter&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">To_type&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">From_type&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">numel&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">convert_op&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// HACK: treat the fragment as a flat &amp;#34;contiguous&amp;#34; array (needs a compact layout)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">auto&lt;/span> &lt;span class="n">frag&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">convert_op&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="k">reinterpret_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">const&lt;/span> &lt;span class="n">cutlass&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Array&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">From_type&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">numel&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="o">*&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">tensor&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">()));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="nf">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_rmem_ptr&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">To_type&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">frag&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tensor&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">());&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="frag-c-to-frag-a-reshape">Frag-C to Frag-A Reshape&lt;/h3>
&lt;p>Now for the reshape. If you recall from our &lt;a class="link" href="#tiled-mma" >MMA Atom thread layout&lt;/a>, each thread&amp;rsquo;s A-fragment holds 8 values of a 16x16 tile, while its C-fragment holds only 4 values of a 16x8 tile. So two adjacent C tiles along N together supply the 8 values of one A tile. To remove any guesswork, we can just print the fragment shapes:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">A fragment layout: ((_2,_2,_2),_1,_1):((_1,_2,_4),_0,_0)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">C fragment layout: ((_2,_2),_1,_1):((_1,_2),_0,_0)
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
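&lt;/div>&lt;p>The second and third modes here are &lt;code>MMA_M&lt;/code> and &lt;code>MMA_N&lt;/code>: the number of MMA tiles that fit along M and N (both 1 in the printout above). In the $SV$ GEMM, the N dimension of S becomes the K dimension, and each k16 step consumes 16 columns, i.e. two 16x8 C tiles. So we just need to merge each adjacent pair of N tiles into one A tile. The &lt;code>MMA_N&lt;/code> mode halves, its stride doubles, and the split-off factor of 2 joins the value mode.&lt;/p>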
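&lt;p>Before writing this in CuTe, we can sanity-check the bookkeeping with a tiny Python model of a layout as nested shape/stride tuples. The function &lt;code>c_to_a&lt;/code> is a hypothetical mirror of the reshape. The sizes are assumed: &lt;code>MMA_M = 2&lt;/code> and &lt;code>MMA_N = 4&lt;/code>, with a compact column-major C layout. The model checks that A tile &lt;code>k&lt;/code> reads exactly the registers of C tiles &lt;code>2k&lt;/code> and &lt;code>2k + 1&lt;/code>. Note that the stride of the new value sub-mode is the old &lt;code>STRIDE_N&lt;/code>, which equals 4 only when &lt;code>MMA_M = 1&lt;/code>.&lt;/p>

```python
import itertools


def flatten(t):
    """Flatten a nested CuTe-style tuple into a flat tuple of ints."""
    if isinstance(t, tuple):
        return tuple(x for e in t for x in flatten(e))
    return (t,)


def offset(stride, coord):
    """Offset of a flat coordinate under a (flattened) stride."""
    return sum(c * d for c, d in zip(flatten(stride), coord))


def c_to_a(shape, stride):
    """Mirror of convert_c_frag_to_a_frag: split MMA_N into (2, MMA_N / 2)
    and fold the 2 into the value mode."""
    (v, m, n), (dv, dm, dn) = shape, stride
    return ((v, 2), m, n // 2), ((dv, dn), dm, dn * 2)


# Assumed example: MMA_M = 2, MMA_N = 4 with a compact column-major C fragment.
c_shape, c_stride = ((2, 2), 2, 4), ((1, 2), 4, 8)
a_shape, a_stride = c_to_a(c_shape, c_stride)
print(a_shape, a_stride)  # (((2, 2), 2), 2, 2) (((1, 2), 8), 4, 16)

# A(a, b, h, m, k) must read the same register as C(a, b, m, 2k + h):
# A tile k is made of C tiles 2k and 2k + 1.
a_coords = list(itertools.product(*(range(s) for s in flatten(a_shape))))
for a, b, h, m, k in a_coords:
    assert offset(a_stride, (a, b, h, m, k)) == offset(c_stride, (a, b, m, 2 * k + h))

# Same 32 registers, just relabeled (a permutation, no data movement).
assert {offset(a_stride, c) for c in a_coords} == set(range(32))
```

&lt;p>The &lt;code>logical_divide&lt;/code> version shown further down produces the same function: dividing the N mode &lt;code>n:dn&lt;/code> by 2 gives &lt;code>(2, n/2):(dn, 2*dn)&lt;/code>, and the first piece is regrouped into the value mode.&lt;/p>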
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">acc_s&lt;/span> &lt;span class="n">C&lt;/span> &lt;span class="n">shape&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">MMA_M&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">MMA_N&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">((&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">STRIDE_M&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">STRIDE_N&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Concatenate along N, each N dim halves, stride doubles&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">acc_s&lt;/span> &lt;span class="n">A&lt;/span> &lt;span class="n">shape&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="p">((&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">MMA_M&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">MMA_N&lt;/span>&lt;span class="o">/&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">((&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">STRIDE_N&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">STRIDE_M&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">STRIDE_N&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>We can write this up pretty easily:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">__forceinline__&lt;/span> &lt;span class="n">__device__&lt;/span> &lt;span class="k">auto&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">convert_c_frag_to_a_frag&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Layout&lt;/span> &lt;span class="k">const&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">acc_layout&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">s&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">acc_layout&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">stride&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">acc_layout&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">stride&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">shape_n&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">s&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">_2&lt;/span>&lt;span class="p">{}),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">s&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">s&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">_2&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">stride_n&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">stride&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">stride&lt;/span>&lt;span class="p">)),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">stride&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">stride&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">_2&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="nf">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">shape_n&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride_n&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This code looks pretty awful, but the logic is simple. Similar to the row-col reshape, the FA2 source opts to use &lt;code>logical_divide&lt;/code> along the N-axis, which is easier to read and accomplishes the same thing:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// _ to keep full dim
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">using&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Underscore&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">l&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">logical_divide&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_layout&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="nf">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_layout&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">l&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">l&lt;/span>&lt;span class="p">)),&lt;/span> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">l&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">get&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">l&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="gemm-rs-loop">GEMM RS Loop&lt;/h3>
&lt;p>All we need to do now is convert &lt;code>acc_s&lt;/code> to fp16 and reshape it into an A-fragment. Then we can reuse our old GEMM loop minus the A-fragment copy lines, since A is already in registers.&lt;/p>
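&lt;p>The loop is software-pipelined one slice deep. V slice 0 is loaded from SMEM before the loop. Iteration &lt;code>i&lt;/code> then issues the load of slice &lt;code>i + 1&lt;/code> before its MMA on slice &lt;code>i&lt;/code>, so the load can overlap the math. A Python trace of that ordering makes the guard on the last iteration easy to check. &lt;code>gemm_rs_schedule&lt;/code> is a hypothetical helper, not part of the kernel.&lt;/p>

```python
def gemm_rs_schedule(num_k):
    """Trace of the register-prefetch loop over num_k slices of V."""
    events = [("copy_V", 0)]  # slice 0 is loaded before the loop
    for i in range(num_k):
        if num_k - 1 > i:  # i is not the last slice: prefetch the next one
            events.append(("copy_V", i + 1))
        events.append(("mma", i))
    return events


trace = gemm_rs_schedule(3)
print(trace)
# [('copy_V', 0), ('copy_V', 1), ('mma', 0), ('copy_V', 2), ('mma', 1), ('mma', 2)]

# Every slice is copied exactly once, and always before the MMA that uses it.
for k in range(3):
    assert trace.count(("copy_V", k)) == 1
    assert trace.index(("mma", k)) > trace.index(("copy_V", k))
```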
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">softmax_stuff&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">acc_s_fp16&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">convert_type&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// reshape to A fragment for next matmul
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tOrP&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s_fp16&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">convert_c_frag_to_a_frag&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_s_fp16&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">layout&lt;/span>&lt;span class="p">()));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tXrV&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_V&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">retile_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tOrVt&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_V&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOsVt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}),&lt;/span> &lt;span class="n">tXrV&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_0&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">#pragma unroll
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="cp">&lt;/span>&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="kt">int&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tOrP&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="o">++&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// prefetch next block
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">i&lt;/span> &lt;span class="o">&amp;lt;&lt;/span> &lt;span class="n">size&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tCrV&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_V&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOsVt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tXrV&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">gemm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOrP&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">tXrV&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">acc_o&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="final-softmax-normalization">Final Softmax normalization&lt;/h2>
&lt;p>Once the main loop finishes, we apply the final softmax normalization, dividing each row of &lt;code>acc_o&lt;/code> by its running softmax denominator. It happens right after the loop. No &lt;code>__syncthreads()&lt;/code> is needed, because the same threads still own the same data and shared memory doesn&amp;rsquo;t come back into play until the output copy we&amp;rsquo;ll cover next.&lt;/p>
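&lt;p>For intuition, here is a minimal host-side C++ sketch of what this final step computes. It is a model, not the kernel&amp;rsquo;s code. &lt;code>acc_o&lt;/code> is a plain row-major float array, &lt;code>row_sum&lt;/code> stands in for the per-row softmax denominator accumulated during the main loop, and &lt;code>normalize_rows&lt;/code> is a hypothetical name. The point is that the division happens once per row, after the loop, instead of on every iteration. The zero/NaN guard for fully masked rows is an assumption of this sketch.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">
```cpp
#include "cassert"
#include "cmath"

// Hypothetical host-side model of the final O normalization.
// acc_o:   rows x cols accumulator, row-major (O before the softmax division)
// row_sum: per-row softmax denominator accumulated in the main loop
void normalize_rows(float* acc_o, const float* row_sum, int rows, int cols) {
  assert(rows >= 0);
  assert(cols >= 0);
  for (int r = 0; r != rows; ++r) {
    float sum = row_sum[r];
    // A fully masked row can end up with sum == 0 (or NaN); leave it
    // unscaled instead of dividing by zero.
    float inv_sum = (sum == 0.f || std::isnan(sum)) ? 1.f : 1.f / sum;
    for (int c = 0; c != cols; ++c) acc_o[r * cols + c] *= inv_sum;
  }
}
```
&lt;/code>&lt;/pre>
&lt;p>Multiplying by a precomputed reciprocal keeps the per-element work to a single multiply, which is why the division is hoisted out of the column loop.&lt;/p>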
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="cm">/* main loop */&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// final o scaling
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">softmax&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">normalize_softmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h1 id="epilogue-output-gmem">Epilogue: Output-&amp;gt;GMEM&lt;/h1>
&lt;p>Our output now sits in register fragments spread across all the warps, and we want to write it back to &lt;code>o_ptr&lt;/code> in GMEM. The most efficient route is a staged write-back:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">Registers-&amp;gt;SMEM-&amp;gt;Registers-&amp;gt;GMEM
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>You might be wondering: if we&amp;rsquo;re going to go from SMEM back to registers anyway, why don&amp;rsquo;t we just write back directly from the fragments? The answer is threefold:&lt;/p>
&lt;ol>
&lt;li>&lt;strong>No SMEM-&amp;gt;GMEM instructions&lt;/strong>: Ampere has the nice async GMEM-&amp;gt;SMEM pipeline, but no direct SMEM-&amp;gt;GMEM pipeline. Therefore, we have to stage the SMEM blocks in registers before writing back to HBM.&lt;/li>
&lt;li>&lt;strong>Vectorization&lt;/strong>: The fragments follow the MMA&amp;rsquo;s per-thread layout, so each thread&amp;rsquo;s halfs are scattered across the tile. To write to GMEM directly, each thread could only store 32 bits (two adjacent halfs) at a time, spread across a bunch of scattered addresses. Given what we know about the memory bus and vectorization, this is excruciatingly inefficient and slow. By staging through SMEM, each thread can write a full 128-bit chunk (8 halfs) in one instruction, like before.&lt;/li>
&lt;li>&lt;strong>Coalescing&lt;/strong>: Furthermore, we can arrange the 32 threads of a warp so that their 16-byte chunks fill whole, aligned 128-byte segments. That way one warp-wide store moves 512 bytes with no wasted bandwidth, like before.&lt;/li>
&lt;/ol>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: Vectorization is per-thread; coalescing is across threads (per-warp).&lt;/p>
&lt;/blockquote>
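&lt;p>To put rough numbers on the vectorization argument, here is a back-of-the-envelope C++ sketch. The configuration is hypothetical and chosen only for round numbers: a 64 x 128 fp16 output tile written by 128 threads (4 warps). It counts per-thread store instructions in two cases: each store moves 32 bits (straight from the fragments), or each store moves 128 bits (after staging through SMEM).&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">
```cpp
#include "cassert"

// Hypothetical tile configuration, for illustration only.
constexpr int kTileRows     = 64;
constexpr int kTileCols     = 128;  // head dim
constexpr int kThreads      = 128;  // 4 warps
constexpr int kWarpSize     = 32;
constexpr int kBytesPerHalf = 2;

// Bytes of O that each thread writes back.
constexpr int bytes_per_thread() {
  return kTileRows * kTileCols * kBytesPerHalf / kThreads;
}

// Store instructions one thread issues at a given vector width.
constexpr int stores_per_thread(int vector_bits) {
  return bytes_per_thread() / (vector_bits / 8);
}

// Bytes moved by one warp-wide store instruction.
constexpr int bytes_per_warp_store(int vector_bits) {
  return kWarpSize * (vector_bits / 8);
}

static_assert(bytes_per_thread() == 128, "64 * 128 * 2 bytes / 128 threads");
static_assert(stores_per_thread(32) == 32, "straight from the fragments");
static_assert(stores_per_thread(128) == 8, "staged through SMEM");
static_assert(bytes_per_warp_store(128) == 512, "one warp, 128-bit stores");
```
&lt;/code>&lt;/pre>
&lt;p>That is a 4x cut in store instructions per thread. Vectorization alone doesn&amp;rsquo;t guarantee that each warp&amp;rsquo;s 512 bytes land in whole cache lines; that part is the coalescing question.&lt;/p>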
&lt;p>Next to a flood of uncoalesced, unvectorized stores, the extra trip through SMEM is pretty much free. As before, the Registers-&amp;gt;SMEM and SMEM-&amp;gt;Registers-&amp;gt;GMEM steps each require their own &lt;code>TiledCopy&lt;/code>, but it should be quicker to figure out what they should be this time around. Before we begin the copy, we convert the output back to fp16 using the same &lt;code>convert_type()&lt;/code> we used for &lt;code>acc_s&lt;/code> before the &lt;code>O = SV&lt;/code> MMA.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Epilogue
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// final output scale
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">softmax&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">normalize_softmax&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// convert o from float back to fp16
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">o_fp16&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">convert_type&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">acc_o&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="the-staged-output-copy-registers--smem--registers--gmem">The Staged Output Copy: Registers → SMEM → Registers → GMEM&lt;/h2>
&lt;p>The output is currently scattered across each thread&amp;rsquo;s registers. Per the &lt;a class="link" href="#epilogue-output-gmem" >opening of this section&lt;/a>, we have to stage it through SMEM to land a coalesced, vectorized GMEM write. Two tiled copies, executed in sequence &amp;ndash; registers → SMEM, then SMEM → registers → GMEM.&lt;/p>
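&lt;p>The two copies work because each element changes owners while it sits in SMEM. The host-side C++ sketch below models that hand-off for a single warp and a 16 x 64 slice of O, under simplifying assumptions of our own:&lt;/p>
&lt;ul>
&lt;li>floats instead of halfs, and no swizzle;&lt;/li>
&lt;li>on the write side, the SM80 &lt;code>m16n8&lt;/code> accumulator layout: lane &lt;code>l&lt;/code> owns rows &lt;code>l/4&lt;/code> and &lt;code>l/4+8&lt;/code>, columns &lt;code>2*(l%4)&lt;/code> and &lt;code>2*(l%4)+1&lt;/code> of every 16 x 8 tile;&lt;/li>
&lt;li>on the read side, 8 lanes per row with 8 contiguous elements each.&lt;/li>
&lt;/ul>
&lt;p>All names are hypothetical. The sketch checks that every element lands where it belongs, and it counts how many SMEM reads pick up a value another lane wrote.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">
```cpp
#include "cassert"

// Hypothetical sizes for illustration: one warp, one 16 x 64 slice of O.
constexpr int kRows      = 16;
constexpr int kCols      = 64;
constexpr int kLanes     = 32;
constexpr int kPerLane   = kRows * kCols / kLanes;  // 32 values per lane
constexpr int kRowStride = 128;                     // GMEM row stride (elements)

struct Coord { int row; int col; };

// Write side: k-th accumulator value owned by `lane` (SM80 m16n8 C fragment).
// Per 16 x 8 tile the order is (r, c), (r, c+1), (r+8, c), (r+8, c+1).
Coord c_frag_coord(int lane, int k) {
  int n_tile = k / 4, j = k % 4;
  return Coord{lane / 4 + 8 * (j / 2), n_tile * 8 + 2 * (lane % 4) + j % 2};
}

// Read side: k-th element a lane loads for the GMEM store. 8 lanes per row,
// 8 contiguous elements (one 128-bit fp16 chunk) each, 4 rows per pass.
Coord gmem_copy_coord(int lane, int k) {
  int pass = k / 8, e = k % 8;
  return Coord{pass * 4 + lane / 8, (lane % 8) * 8 + e};
}

// Moves src (kRows x kCols, row-major) through
// registers -> SMEM -> registers -> GMEM and writes it into gmem.
void staged_store(const float* src, float* smem, float* gmem) {
  float frag[kLanes][kPerLane];  // accumulators as the MMA left them
  float regs[kLanes][kPerLane];  // staging registers for the GMEM store
  for (int l = 0; l != kLanes; ++l)
    for (int k = 0; k != kPerLane; ++k) {
      Coord c = c_frag_coord(l, k);
      frag[l][k] = src[c.row * kCols + c.col];
    }
  // Registers -> SMEM, in C-fragment ownership.
  for (int l = 0; l != kLanes; ++l)
    for (int k = 0; k != kPerLane; ++k) {
      Coord c = c_frag_coord(l, k);
      smem[c.row * kCols + c.col] = frag[l][k];
    }
  // On the GPU, the barrier goes here: the next loop reads values that
  // other lanes (and, with several warps, other warps) wrote.
  for (int l = 0; l != kLanes; ++l)
    for (int k = 0; k != kPerLane; ++k) {
      Coord c = gmem_copy_coord(l, k);
      regs[l][k] = smem[c.row * kCols + c.col];
    }
  // Registers -> GMEM, 8 contiguous elements per lane per pass.
  for (int l = 0; l != kLanes; ++l)
    for (int k = 0; k != kPerLane; ++k) {
      Coord c = gmem_copy_coord(l, k);
      gmem[c.row * kRowStride + c.col] = regs[l][k];
    }
}

// Counts SMEM reads whose value was written by a different lane.
int foreign_reads() {
  int owner[kRows * kCols];
  for (int l = 0; l != kLanes; ++l)
    for (int k = 0; k != kPerLane; ++k) {
      Coord c = c_frag_coord(l, k);
      owner[c.row * kCols + c.col] = l;
    }
  int count = 0;
  for (int l = 0; l != kLanes; ++l)
    for (int k = 0; k != kPerLane; ++k) {
      Coord c = gmem_copy_coord(l, k);
      count += owner[c.row * kCols + c.col] != l ? 1 : 0;
    }
  return count;
}
```
&lt;/code>&lt;/pre>
&lt;p>With these two mappings, &lt;code>foreign_reads()&lt;/code> is nonzero: most reads cross lanes. That is exactly why a barrier has to separate the two SMEM stages.&lt;/p>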
&lt;h3 id="registers--smem">Registers → SMEM&lt;/h3>
&lt;p>We can now begin our register-&amp;gt;SMEM write. We never carved out SMEM for O, but we can simply reuse Q&amp;rsquo;s region: it has exactly the same shape and size as O, and nothing uses it anymore. We can reuse its layout too. The writes have the same bank-conflict problem as the reads, so our swizzled layout from before is a perfect fit.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">sO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sQ&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">SmemLayoutQ&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>The next step is to define the tiled copy. Even though Ampere supports the &lt;code>SM75&lt;/code> (Turing) &lt;code>LDSM&lt;/code> instructions for loading MMA fragments, there is no analogous store instruction. The &lt;code>STSM&lt;/code> instructions were introduced for the H100 Hopper architecture (&lt;code>SM90&lt;/code>), but only God knows why they weren&amp;rsquo;t introduced earlier. Instead, we can just do a typical vectorized copy back to SMEM.&lt;/p>
&lt;p>The FA2 source code takes the lazy route and uses:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemCopyAtomO&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">AutoVectorizingCopyWithAssumedAlignment&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="mi">128&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Why is this lazy? Because &lt;code>AutoVectorizing...&lt;/code> tells CuTe to pick the widest vectorized store that the source and destination layouts allow, up to the 128-bit alignment we promise it. Since 128 bits is the maximum load/store width, we&amp;rsquo;re essentially telling the kernel: hey, you optimize it for me. Reading this bit of FA2 source code can lead you to think a 128-bit vectorized store is possible here, but unfortunately it is not. Let&amp;rsquo;s examine why:&lt;/p>
&lt;p>Recall that in the output fragment, each thread holds 4 values per tile &amp;ndash; thread 0 holds &lt;code>(0, 0), (0, 1), (8, 0), (8, 1)&lt;/code> (see &lt;a class="link" href="#thread-reduce" >thread reduce&lt;/a>). And as we established in &lt;a class="link" href="#registers-arent-memory" >Registers Aren&amp;rsquo;t Memory&lt;/a>, the fragment&amp;rsquo;s &amp;ldquo;column-major&amp;rdquo; layout is fiction; there is no physical column-major memory underneath. So the vectorization here is bounded by what each thread can write contiguously to the &lt;em>output SMEM&lt;/em>, not by what the fragment layout looks like. A PTX &lt;code>st.shared.v2.b16&lt;/code> can take any two registers holding fp16 values and store them to one 32-bit-aligned address. But within a row, each thread owns at most 2 adjacent halfs, so the max vectorization is 32 bits. That limit is set by the MMA hardware&amp;rsquo;s accumulator layout, so no choice of copy atom can beat it. Since the SMEM-&amp;gt;GMEM step is fully vectorized, this isn&amp;rsquo;t a meaningful bottleneck.&lt;/p>
&lt;p>If you want to be exact, you can replace the copy atom above with the one below for superior clarity:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">using&lt;/span> &lt;span class="n">SmemCopyAtomO&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">UniversalCopy&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="kt">uint32_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This is technically more accurate than what the source code specifies. You can check it by compiling the full kernel with fixed-size &lt;code>UniversalCopy&lt;/code> atoms until it compiles; if a width isn&amp;rsquo;t compatible, it&amp;rsquo;ll throw a compile-time error. Other ways to verify include printing the per-thread shapes to see the strides, looking at the raw PTX instructions, or, unfortunately, using your brain.&lt;/p>
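&lt;p>If you&amp;rsquo;d rather have the CPU do the brain work, you can enumerate the accumulator coordinates on the host. This C++ sketch assumes two things: the SM80 &lt;code>m16n8k16&lt;/code> accumulator layout quoted above (thread 0 owns &lt;code>(0, 0), (0, 1), (8, 0), (8, 1)&lt;/code>), and a row-major, unswizzled SMEM tile. &lt;code>max_run&lt;/code> and &lt;code>max_vector_bits&lt;/code> are hypothetical helpers. The width rule is a simplified model of what an auto-vectorizing copy can reach, not CuTe&amp;rsquo;s actual algorithm.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">
```cpp
#include "cassert"

struct Coord { int row; int col; };

// k-th accumulator element (k = 0..3) that `lane` owns in one 16 x 8 tile of
// the SM80 m16n8k16 MMA: rows lane/4 and lane/4 + 8, columns 2*(lane%4) + {0, 1}.
Coord c_frag_coord(int lane, int k) {
  return Coord{lane / 4 + 8 * (k / 2), 2 * (lane % 4) + k % 2};
}

// Longest run of this lane's elements that sit next to each other in a
// row-major SMEM tile, i.e. the most halfs a single store could cover.
int max_run(int lane) {
  int best = 1, run = 1;
  for (int k = 1; k != 4; ++k) {
    Coord prev = c_frag_coord(lane, k - 1);
    Coord cur  = c_frag_coord(lane, k);
    bool adjacent = cur.row == prev.row ? cur.col == prev.col + 1 : false;
    run  = adjacent ? run + 1 : 1;
    best = run > best ? run : best;
  }
  return best;
}

// Simplified model of auto-vectorization: the widest power-of-two store
// that fits inside the contiguous run, capped at 128 bits.
int max_vector_bits(int run_elems, int elem_bits) {
  assert(elem_bits > 0);
  int run_bits = run_elems * elem_bits;
  int cap = run_bits > 128 ? 128 : run_bits;
  int bits = elem_bits;
  while (cap >= bits * 2) bits *= 2;
  return bits;
}
```
&lt;/code>&lt;/pre>
&lt;p>Every one of the 32 lanes tops out at a run of 2 halfs, so the model lands on 32-bit stores, matching &lt;code>UniversalCopy&amp;lt;cute::uint32_t&amp;gt;&lt;/code>. Feed it a run of 8 halfs, as in the SMEM-&amp;gt;GMEM copy, and it reaches the full 128 bits.&lt;/p>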
&lt;h4 id="tiled-copy-1">Tiled Copy&lt;/h4>
&lt;p>Let&amp;rsquo;s make our tiled copy object. To let CuTe know we&amp;rsquo;re working with an output MMA thread layout, we can use &lt;code>make_tiled_copy_C()&lt;/code>&amp;ndash;the &lt;code>C&lt;/code> version instead of the &lt;code>A/B&lt;/code> we used for Q, K, and V.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_tiled_copy_O&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tiled_copy_C&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">SmemCopyAtomO&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">tiled_mma&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">smem_thr_copy_O&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_tiled_copy_O&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>This time we have to retile the registers (now the source) to fit this new copy atom, and partition the SMEM destination &amp;ndash; per the &lt;a class="link" href="#partition-vs-retile" >partition vs. retile rule&lt;/a>, it&amp;rsquo;s &lt;code>retile_S&lt;/code> for the register source and &lt;code>partition_D&lt;/code> for the SMEM destination. Then, we can issue our copy.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// retile_S this time, since it&amp;#39;s the source
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">trO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_O&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">retile_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">o_fp16&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">auto&lt;/span> &lt;span class="n">tsO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">smem_thr_copy_O&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// copy, and we&amp;#39;re done!
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">smem_tiled_copy_O&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">trO&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tsO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h3 id="smem--registers--gmem">SMEM → Registers → GMEM&lt;/h3>
&lt;p>We create the GMEM output tile the same way we created the Q, K, V source tiles. It has exactly the same layout as Q. The only difference is that it&amp;rsquo;s not &lt;code>const&lt;/code>, since we&amp;rsquo;re modifying its contents:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">mO&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">make_gmem_ptr&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">reinterpret_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span> &lt;span class="o">*&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">o_ptr&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">batch_idx&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">o_batch_stride&lt;/span> &lt;span class="o">+&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">head_idx&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">o_head_stride&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">seqlen_q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">head_dim&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_stride&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">q_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_1&lt;/span>&lt;span class="p">{}));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">gO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">local_tile&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">mO&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">make_shape&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Int&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">kHeadDim&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">{}),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">make_coord&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m_block&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h4 id="smem-gmem-tiled-copy">SMEM-&amp;gt;GMEM Tiled Copy&lt;/h4>
&lt;p>Our tiled copy is pretty much the same as the one for Q, except we use a standard synchronous 128-bit copy atom instead of the &lt;code>cp.async&lt;/code> we used for GMEM-&amp;gt;SMEM copies. This time, our 128-bit &lt;code>AutoVectorizing&lt;/code> cheat is fine: both the SMEM loads and the GMEM stores really are full 128-bit vectors. You can still explicitly declare a 128-bit &lt;code>UniversalCopy&lt;/code> for clarity. As with the GMEM-&amp;gt;SMEM load, we still benefit from memory coalescing, because the chunks written &lt;em>TO&lt;/em> GMEM are contiguous there, even if they aren&amp;rsquo;t contiguous in SMEM.&lt;/p>
&lt;blockquote>
&lt;p>&lt;strong>Note&lt;/strong>: Remember, memory coalescing is a &lt;em>GMEM optimization&lt;/em>. The swizzle is our valet attendant&amp;ndash;it stores and retrieves our car on the thread&amp;rsquo;s behalf, and we don&amp;rsquo;t really care &lt;em>where&lt;/em> it puts it. Once each thread retrieves its &amp;ldquo;car&amp;rdquo;, they each park them contiguously in GMEM, which is all that matters for coalescing.&lt;/p>
&lt;/blockquote>
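&lt;p>To make the coalescing claim concrete, the C++ sketch below computes the GMEM byte address each lane writes during one warp-wide 128-bit store. It reuses the batch/head/row offset arithmetic from &lt;code>mO&lt;/code> and &lt;code>gO&lt;/code>. It assumes that &lt;code>GmemLayout&lt;/code> places 8 lanes per row (row-major), each lane storing the 8 halfs given by the &lt;code>Shape&amp;lt;_1, _8&amp;gt;&lt;/code> value layout in the tiled copy that follows. The stride struct and helper names are hypothetical. The check is a simplified model that counts 128-byte lines rather than individual 32-byte sectors.&lt;/p>
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">
```cpp
#include "cassert"

constexpr long long kBytesPerHalf = 2;
constexpr long long kLineBytes    = 128;  // one cache line (4 x 32-byte sectors)

// Hypothetical stand-in for the stride fields used to build mO (in elements).
struct OutStrides { long long batch, head, row; };

// Element offset of O[batch][head][row][col], the same arithmetic as mO/gO.
long long o_offset(OutStrides s, int batch, int head, int row, int col) {
  return batch * s.batch + head * s.head + row * s.row + col;
}

// Byte address of the 128-bit chunk `lane` stores in pass `pass` of the
// SMEM -> GMEM copy for query tile m_block. Assumes 8 lanes per row and
// 8 halfs (16 bytes) per lane, so one warp covers 4 rows per pass.
long long lane_store_byte(OutStrides s, int batch, int head, int m_block,
                          int block_m, int pass, int lane) {
  int row = m_block * block_m + pass * 4 + lane / 8;
  int col = (lane % 8) * 8;
  return o_offset(s, batch, head, row, col) * kBytesPerHalf;
}

struct WarpStoreStats {
  int lines;      // distinct 128-byte lines the warp touches
  bool all_full;  // every touched line is written completely
};

// Simplified coalescing check for one warp-wide 128-bit store.
WarpStoreStats warp_store_stats(OutStrides s, int batch, int head,
                                int m_block, int block_m, int pass) {
  long long line_of[32];
  for (int lane = 0; lane != 32; ++lane)
    line_of[lane] =
        lane_store_byte(s, batch, head, m_block, block_m, pass, lane) / kLineBytes;
  WarpStoreStats stats{0, true};
  for (int lane = 0; lane != 32; ++lane) {
    int lanes_in_line = 0;
    bool first = true;  // is `lane` the lowest lane in its line?
    for (int other = 0; other != 32; ++other) {
      if (line_of[other] != line_of[lane]) continue;
      ++lanes_in_line;
      if (lane > other) first = false;
    }
    if (first) ++stats.lines;
    // Lanes write disjoint, 16-byte-aligned chunks: 8 of them fill a line.
    if (lanes_in_line * 16 != kLineBytes) stats.all_full = false;
  }
  return stats;
}
```
&lt;/code>&lt;/pre>
&lt;p>With a contiguous row stride of 128 halfs, each pass touches 4 lines and fills every one of them. If you pad the rows to 136 halfs, the same store straddles line boundaries: it touches 7 lines, several only partly written. The swizzle never appears in this calculation, because only the GMEM addresses matter.&lt;/p>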
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;span class="lnt">7
&lt;/span>&lt;span class="lnt">8
&lt;/span>&lt;span class="lnt">9
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// can reuse the 128-bit auto vectorized SMEM copy
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// if you used 32-bit universal, then you&amp;#39;ll have to redefine it here
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">using&lt;/span> &lt;span class="n">SmemCopyAtomO&lt;/span> &lt;span class="o">=&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Copy_Atom&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">UniversalCopy&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">uint128_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">half_t&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// Copy_Atom&amp;lt;AutoVectorizingCopyWithAssumedAlignment&amp;lt;128&amp;gt;, cute::half_t&amp;gt;;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// same gmem layout as QKV
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">gmem_tiled_copy_O&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_tiled_copy&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">SmemCopyAtomO&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">GmemLayout&lt;/span>&lt;span class="p">{},&lt;/span> &lt;span class="n">Layout&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Shape&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">_1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">_8&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">{});&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Now we can create our thread slice and partition our source and destination memory.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// thread Output _ Output
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="k">auto&lt;/span> &lt;span class="n">gmem_thr_copy_O&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_tiled_copy_O&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">get_thread_slice&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tid&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tOsO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_thr_copy_O&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_S&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tOgO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">gmem_thr_copy_O&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">partition_D&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>Since there is no direct SMEM-&amp;gt;GMEM instruction, we cannot simply &lt;code>copy(..., tOsO, tOgO)&lt;/code>. Instead, we stage through a register fragment of the same shape. This is pretty much free, since the main loop&amp;rsquo;s working registers are no longer needed at this point. The tiled copy doesn&amp;rsquo;t know or care whether its source/dest is GMEM, SMEM, or registers. So we just make a register clone of &lt;code>tOgO&lt;/code> with &lt;code>make_fragment_like&lt;/code> and issue two copies using the same &lt;code>TiledCopy&lt;/code>:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;span class="lnt">3
&lt;/span>&lt;span class="lnt">4
&lt;/span>&lt;span class="lnt">5
&lt;/span>&lt;span class="lnt">6
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// register buffer
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="n">tOrO&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">make_fragment_like&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">tOgO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// smem-&amp;gt;registers
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_O&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOsO&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOrO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// registers-&amp;gt;gmem
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="n">cute&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">copy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">gmem_tiled_copy_O&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOrO&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">tOgO&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>OMG! We&amp;rsquo;re&amp;hellip;&lt;/p>
&lt;h4 id="sync-threads">Sync Threads&lt;/h4>
&lt;p>Not quite yet. We have one last dance. There&amp;rsquo;s a slight bug in our epilogue as-is. The register-&amp;gt;SMEM copy and the SMEM-&amp;gt;GMEM copy partition the tile differently. As a result, the thread that writes a given element of SMEM in the first copy is generally not the thread that reads it in the second. The async waits we set up earlier don&amp;rsquo;t cover this, because the O store uses ordinary synchronous copies. So we simply call &lt;code>__syncthreads()&lt;/code> somewhere between the two copy stages. The FA2 production code puts it right before the final two &lt;code>copy()&lt;/code> invocations, so that the GMEM tiled copy setup overlaps with the register-&amp;gt;SMEM copy. In practice it probably doesn&amp;rsquo;t make much of a difference, and you can just put the sync right after the register-&amp;gt;SMEM copy.&lt;/p>
&lt;h1 id="plumbing-params-launch-dispatch-kernel-traits">Plumbing: Params, Launch, Dispatch, Kernel Traits&lt;/h1>
&lt;blockquote>
&lt;p>&lt;strong>Source:&lt;/strong> &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash.h" target="_blank" rel="noopener"
>&lt;code>flash.h&lt;/code>&lt;/a>, &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash_fwd_launch_template.h" target="_blank" rel="noopener"
>&lt;code>flash_fwd_launch_template.h&lt;/code>&lt;/a>, &lt;a class="link" href="https://github.com/cloudui/cuda-triton/tree/main/cuda/flash_attn_cutlass" target="_blank" rel="noopener"
>&lt;code>flash_fwd_hdim{32,64,128}_fp16_sm80.cu&lt;/code>&lt;/a>, &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/flash_api.cu" target="_blank" rel="noopener"
>&lt;code>flash_api.cu&lt;/code>&lt;/a>, &lt;a class="link" href="https://github.com/cloudui/cuda-triton/blob/main/cuda/flash_attn_cutlass/kernel_traits.cuh" target="_blank" rel="noopener"
>&lt;code>kernel_traits.cuh&lt;/code>&lt;/a>&lt;/p>
&lt;/blockquote>
&lt;p>The kernel itself is done, but we still need the usual host-side CUDA plumbing to make it runnable. We&amp;rsquo;re going to speedrun these files since they aren&amp;rsquo;t relevant to the algorithm itself. Still, I&amp;rsquo;ll briefly explain which remaining files we need before we can call FA2 from something like PyTorch. It&amp;rsquo;s not important to comb through every line; most of the time, you&amp;rsquo;ll just copy the boilerplate from some old kernel you made and tweak a few variables.&lt;/p>
&lt;h2 id="kernel-traits">Kernel Traits&lt;/h2>
&lt;p>As we mentioned in &lt;a class="link" href="#code-layout-the-repo" >the code layout section&lt;/a>, every tiled MMA, tiled copy, and layout is defined in &lt;code>kernel_traits.cuh&lt;/code>. It&amp;rsquo;s the backbone of every operation in &lt;code>flash_fwd_kernel.h&lt;/code>: every type we declared with &lt;code>using&lt;/code> and every compile-time constant lives there. In this post we inlined those definitions next to the code that uses them for readability. In the actual kernel, type references look more like this:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">typename&lt;/span> &lt;span class="n">Traits&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">GmemTiledCopyQKV&lt;/span> &lt;span class="n">gmem_tiled_copy_QKV&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;h2 id="params-struct">Params Struct&lt;/h2>
&lt;p>Everything the kernel needs is funneled through a simple data struct, &lt;code>Flash_fwd_params&lt;/code>. Pointers for Q, K, V, and O; dimensions; per-tensor strides; and the precomputed softmax scale. The struct lives in &lt;code>flash.h&lt;/code> because both the kernel side (&lt;code>flash_fwd_kernel.h&lt;/code>) and the host side (&lt;code>flash_api.cu&lt;/code>) include it.&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">struct&lt;/span> &lt;span class="nc">Flash_fwd_params&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">void&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">__restrict__&lt;/span> &lt;span class="n">q_ptr&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">__restrict__&lt;/span> &lt;span class="n">k_ptr&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">__restrict__&lt;/span> &lt;span class="n">v_ptr&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">void&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">__restrict__&lt;/span> &lt;span class="n">o_ptr&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// unused, i just copied it over; only used in bwd pass
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="o">*&lt;/span>&lt;span class="n">__restrict__&lt;/span> &lt;span class="n">softmax_lse_ptr&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="c1">// (batch, num_heads, seqlen_q)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">int&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seqlen_q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seqlen_k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads_k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">head_dim&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1">// PyTorch passes strides in elements (not bytes)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">q_batch_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">q_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">q_head_stride&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">int&lt;/span> &lt;span class="n">k_batch_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">k_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">k_head_stride&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">int&lt;/span> &lt;span class="n">v_batch_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">v_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">v_head_stride&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">int&lt;/span> &lt;span class="n">o_batch_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">o_row_stride&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">o_head_stride&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="kt">float&lt;/span> &lt;span class="n">scale_softmax&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="c1">// 1/sqrt(d_h)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="kt">float&lt;/span> &lt;span class="n">scale_softmax_log2&lt;/span>&lt;span class="p">;&lt;/span> &lt;span class="c1">// 1/sqrt(d_h) * log2(e), for exp2()
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;ul>
&lt;li>The &lt;strong>&lt;code>__restrict__&lt;/code>&lt;/strong> qualifier promises the compiler that memory reached through one pointer is not also accessed through another. That lets it reorder and cache loads more aggressively.&lt;/li>
&lt;li>Strides are counted in elements, not bytes, which is what PyTorch&amp;rsquo;s &lt;code>stride()&lt;/code> reports.&lt;/li>
&lt;li>&lt;strong>&lt;code>scale_softmax_log2&lt;/code>&lt;/strong> folds the softmax scale and &lt;code>log2(e)&lt;/code> into a single constant on the host. Using the identity e&lt;sup>y&lt;/sup> = 2&lt;sup>y&amp;middot;log2(e)&lt;/sup>, the kernel can then evaluate the exponential with the cheaper &lt;code>exp2()&lt;/code>, without rebuilding the combined factor in its inner loop (see &lt;a class="link" href="#softmax_scale_log2-dont-forget-the-scaling-factor" >the softmax scaling section&lt;/a>).&lt;/li>
&lt;/ul>
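&lt;p>To make the stride and scale bookkeeping concrete, here is a small host-only sketch. It assumes a contiguous &lt;code>(batch, seqlen, num_heads, head_dim)&lt;/code> tensor, the layout the binding&amp;rsquo;s shape comments describe. The helper names &lt;code>BshdStrides&lt;/code>, &lt;code>contiguous_bshd_strides&lt;/code>, &lt;code>softmax_scale&lt;/code>, and &lt;code>softmax_scale_log2&lt;/code> are hypothetical and not part of the repo. The sketch also spells out why &lt;code>scale_softmax_log2&lt;/code> works: &lt;code>exp(x * scale)&lt;/code> equals &lt;code>exp2(x * scale * log2(e))&lt;/code>.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">#include &amp;lt;cmath&amp;gt;
#include &amp;lt;cstdint&amp;gt;

// Hypothetical helper (not from the repo): element strides of a contiguous
// (batch, seqlen, num_heads, head_dim) tensor, i.e. what tensor.stride()
// reports for dims 0, 1, 2. Strides count elements, not bytes.
struct BshdStrides {
    std::int64_t batch_stride, row_stride, head_stride;
};

inline BshdStrides contiguous_bshd_strides(std::int64_t seqlen, std::int64_t num_heads,
                                           std::int64_t head_dim) {
    BshdStrides s;
    s.head_stride = head_dim;                        // next head: skip one head_dim vector
    s.row_stride = num_heads * head_dim;             // next token: skip every head
    s.batch_stride = seqlen * num_heads * head_dim;  // next sequence: skip every token
    return s;
}

// The two scales the host precomputes, mirroring the Flash_fwd_params fields.
inline float softmax_scale(int head_dim) {
    return 1.0f / std::sqrt(static_cast&amp;lt;float&amp;gt;(head_dim));
}

inline float softmax_scale_log2(int head_dim) {
    // log2(e) written out, so we don't rely on the non-standard M_LOG2E macro
    const float kLog2e = 1.4426950408889634f;
    return softmax_scale(head_dim) * kLog2e;
}
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>For example, take &lt;code>seqlen = 1024&lt;/code>, 16 heads, and &lt;code>head_dim = 64&lt;/code>. The head, row, and batch strides come out to 64, 1024, and 1,048,576 elements, and &lt;code>softmax_scale(64)&lt;/code> is exactly 0.125. In the real binding you would read the strides straight off the tensors with &lt;code>q.stride(0)&lt;/code>, &lt;code>q.stride(1)&lt;/code>, and &lt;code>q.stride(2)&lt;/code> rather than recompute them. That way, layouts that aren&amp;rsquo;t fully contiguous are still described correctly.&lt;/p>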
&lt;h2 id="launch-template">Launch Template&lt;/h2>
&lt;p>This is the host-side function in &lt;code>flash_fwd_launch_template.h&lt;/code> that picks the grid shape, configures SMEM, and launches the kernel. We instantiate the template once per &lt;code>Traits&lt;/code> and expose one runtime entry point per head_dim:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="k">template&lt;/span> &lt;span class="o">&amp;lt;&lt;/span>&lt;span class="k">typename&lt;/span> &lt;span class="n">Traits&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kt">void&lt;/span> &lt;span class="n">run_flash_fwd&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Flash_fwd_params&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cudaStream_t&lt;/span> &lt;span class="n">stream&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">kBlockM&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Traits&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">kBlockM&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">constexpr&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">smem_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Traits&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">kSmemSize&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">num_m_blocks&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">seqlen_q&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">kBlockM&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">kBlockM&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dim3&lt;/span> &lt;span class="nf">grid&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_m_blocks&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">batch_size&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">num_heads&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">dim3&lt;/span> &lt;span class="nf">block&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Traits&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">kNThreads&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">kernel&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">FLASH&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">flash_fwd_kernel&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Traits&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">smem_size&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">48&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">1024&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">cudaFuncSetAttribute&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">kernel&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cudaFuncAttributeMaxDynamicSharedMemorySize&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">smem_size&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">kernel&lt;/span>&lt;span class="o">&amp;lt;&amp;lt;&amp;lt;&lt;/span>&lt;span class="n">grid&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">block&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">smem_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stream&lt;/span>&lt;span class="o">&amp;gt;&amp;gt;&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kr">inline&lt;/span> &lt;span class="kt">void&lt;/span> &lt;span class="nf">run_mha_fwd_hdim32&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Flash_fwd_params&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">p&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cudaStream_t&lt;/span> &lt;span class="n">s&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span> &lt;span class="n">run_flash_fwd&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Traits_hdim32&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">p&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">s&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kr">inline&lt;/span> &lt;span class="kt">void&lt;/span> &lt;span class="nf">run_mha_fwd_hdim64&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Flash_fwd_params&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">p&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cudaStream_t&lt;/span> &lt;span class="n">s&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span> &lt;span class="n">run_flash_fwd&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Traits_hdim64&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">p&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">s&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kr">inline&lt;/span> &lt;span class="kt">void&lt;/span> &lt;span class="nf">run_mha_fwd_hdim128&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Flash_fwd_params&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">p&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">cudaStream_t&lt;/span> &lt;span class="n">s&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span> &lt;span class="n">run_flash_fwd&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">Traits_hdim128&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">p&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">s&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;ul>
&lt;li>&lt;strong>Grid shape.&lt;/strong> Two-dimensional: &lt;code>(num_m_blocks, batch * heads)&lt;/code>. The &lt;code>num_m_blocks&lt;/code> math is the usual integer trick for &lt;code>ceil(seqlen_q / kBlockM)&lt;/code>; it&amp;rsquo;s overkill for us since we assume nice dims. The &lt;code>batch * heads&lt;/code> axis flattens batch and head into one dimension. Every (batch, head) pair is completely independent, so flattening them costs nothing and keeps indexing simple.&lt;/li>
&lt;li>&lt;strong>We make &lt;code>m_block&lt;/code> the &lt;em>first&lt;/em> grid dimension.&lt;/strong> &lt;code>blockIdx.x&lt;/code> varies fastest in the block numbering, so CTAs (thread blocks) with neighboring IDs are different Q tiles of the same batch/head. Those CTAs all read the same K and V. In practice neighboring CTAs tend to run around the same time, so those K/V tiles are more likely to still be sitting in L2. If batch/head were the fast axis, neighboring CTAs would read different K and V and your cache would cry.&lt;/li>
&lt;li>&lt;strong>Block shape.&lt;/strong> &lt;code>kNThreads = kNWarps * 32&lt;/code> &amp;ndash; 128 threads for &lt;code>kNWarps = 4&lt;/code>. One thread block per Q-tile + fixed thread count.&lt;/li>
&lt;li>&lt;strong>&lt;code>cudaFuncSetAttribute&lt;/code> for extended SMEM.&lt;/strong> By default a CTA can use at most 48 KB of SMEM. Going beyond that requires an explicit per-kernel opt-in, up to 163 KB per block on A100 (sm80). Our kernel needs more than the default: the Q tile plus the K and V tiles easily exceed 48 KB at &lt;code>hdim=128&lt;/code> (see &lt;code>kSmemSize&lt;/code> in &lt;code>kernel_traits.cuh&lt;/code>). We have to tell the GPU we&amp;rsquo;re not simpletons working with tiny SMEM allocations, so we raise &lt;code>cudaFuncAttributeMaxDynamicSharedMemorySize&lt;/code> to &lt;code>smem_size&lt;/code> before launching.&lt;/li>
&lt;li>The three &lt;code>inline&lt;/code> wrappers exist purely to keep &lt;code>flash_api.cu&lt;/code>&amp;rsquo;s dispatch readable: each &lt;code>switch&lt;/code> case becomes a single call.&lt;/li>
&lt;/ul>
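&lt;p>Here is a host-only sketch of the numbers the launcher works with. &lt;code>TraitsSketch&lt;/code> is a hypothetical, stripped-down stand-in for a traits struct that keeps only the constants the launcher reads. It assumes fp16 tiles and an SMEM footprint of one Q tile plus one K tile and one V tile. The tile sizes in the example are illustrative, not the repo&amp;rsquo;s actual choices. &lt;code>cta_from_linear_id&lt;/code> is also hypothetical. It assumes the kernel unflattens &lt;code>blockIdx.y&lt;/code> as &lt;code>batch * num_heads + head&lt;/code> and shows which tile each CTA gets when &lt;code>blockIdx.x&lt;/code> is the fast axis.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">#include &amp;lt;cstdint&amp;gt;

// Hypothetical stand-in for a Traits struct: only the compile-time constants
// the launcher reads. The real traits also carry the tiled MMA, tiled copies,
// and SMEM layouts.
template &amp;lt;int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_&amp;gt;
struct TraitsSketch {
    static constexpr int kHeadDim = kHeadDim_;
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;  // 32 threads per warp
    // Assumed SMEM footprint: one Q tile + one K tile + one V tile, 2 bytes per fp16.
    static constexpr int kSmemSize =
        (kBlockM + 2 * kBlockN) * kHeadDim * static_cast&amp;lt;int&amp;gt;(sizeof(std::uint16_t));
};

// Same ceil-div the launcher uses to count Q tiles.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// True when the launcher must raise cudaFuncAttributeMaxDynamicSharedMemorySize.
constexpr bool needs_smem_opt_in(int smem_bytes) { return smem_bytes &amp;gt; 48 * 1024; }

struct CtaCoord {
    int m_block, batch, head;
};

// Hypothetical: linear CTA id -&amp;gt; (m_block, batch, head) when blockIdx.x varies
// fastest, assuming blockIdx.y == batch * num_heads + head.
inline CtaCoord cta_from_linear_id(int linear_id, int num_m_blocks, int num_heads) {
    const int bx = linear_id % num_m_blocks;  // blockIdx.x: which Q tile
    const int by = linear_id / num_m_blocks;  // blockIdx.y: flattened (batch, head)
    return CtaCoord{bx, by / num_heads, by % num_heads};
}
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>With these illustrative sizes, &lt;code>TraitsSketch&amp;lt;128, 128, 64, 4&amp;gt;&lt;/code> needs (128 + 2 &amp;times; 64) &amp;times; 128 &amp;times; 2 = 65,536 bytes, so it has to opt in. &lt;code>TraitsSketch&amp;lt;64, 128, 64, 4&amp;gt;&lt;/code> needs 32,768 bytes and fits under the default 48 KB. For a grid with 8 Q tiles, CTAs 0 through 7 all map to batch 0, head 0 and therefore read the same K and V. That is exactly the L2 reuse the grid ordering is after.&lt;/p>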
&lt;h2 id="per-config-instantiations">Per-config Instantiations&lt;/h2>
&lt;p>If you look at the source repo, you&amp;rsquo;ll see that 80% of the files in the folder are just &lt;code>flash_fwd_hdim{32,64,128}_fp16_sm80.cu&lt;/code>:&lt;/p>
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt">1
&lt;/span>&lt;span class="lnt">2
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="c1">// flash_fwd_hdim64_fp16_sm80.cu
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span>&lt;span class="cp">#include&lt;/span> &lt;span class="cpf">&amp;#34;flash_fwd_launch_template.h&amp;#34;&lt;/span>&lt;span class="cp">
&lt;/span>&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;p>They&amp;rsquo;re near-empty boilerplate whose only job is to prevent unnecessary recompilation. The point is to give &lt;code>nvcc&lt;/code> one &lt;code>.cu&lt;/code> per &lt;code>(head_dim, dtype, arch)&lt;/code> combination, so the build system can compile them in parallel and incremental builds only rebuild what changed. Keep in mind that a template is only instantiated in a translation unit that actually uses it. The split therefore only pays off if each file triggers its own config&amp;rsquo;s instantiation. Upstream FA2 does this with an explicit specialization in each file that calls the matching &lt;code>run_mha_fwd_hdim*&lt;/code> wrapper. If we bundled all three configs into one file, &lt;code>nvcc&lt;/code> would happily recompile all of them every build and force you to take an extra long bathroom break.&lt;/p>
&lt;h2 id="pytorch-binding">PyTorch Binding&lt;/h2>
&lt;blockquote>
&lt;p>&lt;strong>Source&lt;/strong>: &lt;code>flash_api.cu&lt;/code>&lt;/p>
&lt;/blockquote>
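&lt;p>A thin host-side wrapper so you can call the kernel from PyTorch. It reads the shapes off the input tensors, allocates the outputs, fills in &lt;code>Flash_fwd_params&lt;/code>, and dispatches on &lt;code>head_dim&lt;/code>.&lt;/p>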
&lt;div class="highlight">&lt;div class="chroma">
&lt;table class="lntable">&lt;tr>&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code>&lt;span class="lnt"> 1
&lt;/span>&lt;span class="lnt"> 2
&lt;/span>&lt;span class="lnt"> 3
&lt;/span>&lt;span class="lnt"> 4
&lt;/span>&lt;span class="lnt"> 5
&lt;/span>&lt;span class="lnt"> 6
&lt;/span>&lt;span class="lnt"> 7
&lt;/span>&lt;span class="lnt"> 8
&lt;/span>&lt;span class="lnt"> 9
&lt;/span>&lt;span class="lnt">10
&lt;/span>&lt;span class="lnt">11
&lt;/span>&lt;span class="lnt">12
&lt;/span>&lt;span class="lnt">13
&lt;/span>&lt;span class="lnt">14
&lt;/span>&lt;span class="lnt">15
&lt;/span>&lt;span class="lnt">16
&lt;/span>&lt;span class="lnt">17
&lt;/span>&lt;span class="lnt">18
&lt;/span>&lt;span class="lnt">19
&lt;/span>&lt;span class="lnt">20
&lt;/span>&lt;span class="lnt">21
&lt;/span>&lt;span class="lnt">22
&lt;/span>&lt;span class="lnt">23
&lt;/span>&lt;span class="lnt">24
&lt;/span>&lt;span class="lnt">25
&lt;/span>&lt;span class="lnt">26
&lt;/span>&lt;span class="lnt">27
&lt;/span>&lt;span class="lnt">28
&lt;/span>&lt;span class="lnt">29
&lt;/span>&lt;span class="lnt">30
&lt;/span>&lt;span class="lnt">31
&lt;/span>&lt;span class="lnt">32
&lt;/span>&lt;span class="lnt">33
&lt;/span>&lt;span class="lnt">34
&lt;/span>&lt;span class="lnt">35
&lt;/span>&lt;/code>&lt;/pre>&lt;/td>
&lt;td class="lntd">
&lt;pre tabindex="0" class="chroma">&lt;code class="language-cpp" data-lang="cpp">&lt;span class="line">&lt;span class="cl">&lt;span class="n">std&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">vector&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Tensor&lt;/span>&lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">mha_fwd&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">q&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (batch, seqlen_q, num_heads, head_dim)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">k&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="c1">// (batch, seqlen_k, num_heads_k, head_dim)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1">&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">Tensor&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">v&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">batch_size&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">seqlen_q&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">num_heads&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">head_dim&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">seqlen_k&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">k&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">const&lt;/span> &lt;span class="kt">int&lt;/span> &lt;span class="n">num_heads_k&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">k&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">empty_like&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">q&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">softmax_lse&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">empty&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">{&lt;/span>&lt;span class="n">batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_heads&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">seqlen_q&lt;/span>&lt;span class="p">},&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">kFloat32&lt;/span>&lt;span class="p">).&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">()));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Flash_fwd_params&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">q_ptr&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">q&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">data_ptr&lt;/span>&lt;span class="p">();&lt;/span> &lt;span class="cm">/* ...fill in pointers, dims, strides... */&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">scale_softmax&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">1.0f&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">std&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">sqrt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="k">static_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">float&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">head_dim&lt;/span>&lt;span class="p">));&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">scale_softmax_log2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">params&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">scale_softmax&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="k">static_cast&lt;/span>&lt;span class="o">&amp;lt;&lt;/span>&lt;span class="kt">float&lt;/span>&lt;span class="o">&amp;gt;&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">M_LOG2E&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">auto&lt;/span> &lt;span class="n">stream&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">at&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">::&lt;/span>&lt;span class="n">getCurrentCUDAStream&lt;/span>&lt;span class="p">().&lt;/span>&lt;span class="n">stream&lt;/span>&lt;span class="p">();&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">switch&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">head_dim&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">case&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="o">:&lt;/span> &lt;span class="n">run_mha_fwd_hdim32&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stream&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="k">break&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">case&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="o">:&lt;/span> &lt;span class="n">run_mha_fwd_hdim64&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stream&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="k">break&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">case&lt;/span> &lt;span class="mi">128&lt;/span>&lt;span class="o">:&lt;/span> &lt;span class="n">run_mha_fwd_hdim128&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">params&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stream&lt;/span>&lt;span class="p">);&lt;/span> &lt;span class="k">break&lt;/span>&lt;span class="p">;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">default&lt;/span>&lt;span class="o">:&lt;/span> &lt;span class="n">TORCH_CHECK&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">false&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s">&amp;#34;unsupported head_dim &amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">head_dim&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">output&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">softmax_lse&lt;/span>&lt;span class="p">};&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">PYBIND11_MODULE&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">TORCH_EXTENSION_NAME&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">m&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">m&lt;/span>&lt;span class="p">.&lt;/span>&lt;span class="n">def&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s">&amp;#34;mha_fwd&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">&amp;amp;&lt;/span>&lt;span class="n">mha_fwd&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s">&amp;#34;FlashAttention-2 forward (CUDA)&amp;#34;&lt;/span>&lt;span class="p">);&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/td>&lt;/tr>&lt;/table>
&lt;/div>
&lt;/div>&lt;ul>
&lt;li>&lt;strong>Output allocation happens on the host.&lt;/strong> &lt;code>torch::empty_like(q)&lt;/code> and &lt;code>torch::empty({...})&lt;/code> are called here, not in the kernel. The kernel just writes into a pre-allocated buffer.&lt;/li>
&lt;li>&lt;strong>&lt;code>at::cuda::getCurrentCUDAStream()&lt;/code>&lt;/strong> gets the CUDA stream PyTorch is currently using. If your user has &lt;code>torch.cuda.stream(...)&lt;/code> set, your kernel runs on that stream and properly interleaves with everything else; if they don&amp;rsquo;t, you get the default stream. Use this and not &lt;code>cudaStreamDefault&lt;/code> &amp;ndash; the latter ignores user-set stream context and will make you cry.&lt;/li>
&lt;li>&lt;strong>&lt;code>PYBIND11_MODULE&lt;/code>&lt;/strong> is the standard pybind11 entry point. The macro registers &lt;code>mha_fwd&lt;/code> as a Python-callable function, so from Python it&amp;rsquo;s just &lt;code>import flash_attn_cutlass; flash_attn_cutlass.mha_fwd(q, k, v)&lt;/code>.&lt;/li>
&lt;/ul>
&lt;p>That&amp;rsquo;s the entire boilerplate stack. You call &lt;code>mha_fwd(q, k, v)&lt;/code> in Python; pybind11 hands the call to our C++ function; we fill in &lt;code>Flash_fwd_params&lt;/code> and pick a launcher by head_dim; the launcher computes the grid, enables extended SMEM, and launches the kernel. The &lt;code>Makefile&lt;/code> in my repo has the commands to compile, test, and benchmark every kernel I wrote. Happy testing&amp;hellip;&lt;/p>
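&lt;p>To make the host-side bookkeeping a bit more concrete, here is a rough pure-Python mirror of it. Treat it as a sketch under stated assumptions: the launcher names are copied from the C++ &lt;code>switch&lt;/code> as plain strings, &lt;code>softmax_scales&lt;/code> and &lt;code>pick_launcher&lt;/code> are hypothetical helpers, and nothing here touches the real extension. What it does show is the identity behind &lt;code>scale_softmax_log2&lt;/code>: &lt;code>exp(x * s) == 2 ** (x * s * log2(e))&lt;/code>, so folding &lt;code>M_LOG2E&lt;/code> into the scale on the host presumably lets the kernel use a base-2 exponential instead of &lt;code>exp&lt;/code>.&lt;/p>

```python
import math

# Hypothetical pure-Python mirror of the host wrapper's bookkeeping.
# Launcher names are placeholder strings copied from the C++ switch;
# nothing here calls the real flash_attn_cutlass extension.
LAUNCHERS = {
    32: "run_mha_fwd_hdim32",
    64: "run_mha_fwd_hdim64",
    128: "run_mha_fwd_hdim128",
}


def softmax_scales(head_dim):
    """Return (scale_softmax, scale_softmax_log2) the way the C++ wrapper does."""
    scale = 1.0 / math.sqrt(head_dim)
    # M_LOG2E == log2(e). Pre-multiplying means exp(x * scale) can be
    # computed as 2 ** (x * scale_log2), presumably via exp2 on the GPU.
    return scale, scale * math.log2(math.e)


def pick_launcher(head_dim):
    """Mimic switch (head_dim): one launcher per supported size, else an error."""
    if head_dim not in LAUNCHERS:
        raise ValueError(f"unsupported head_dim {head_dim}")
    return LAUNCHERS[head_dim]


scale, scale_log2 = softmax_scales(64)
score = 3.7  # an arbitrary raw Q.K^T dot product
assert math.isclose(math.exp(score * scale), 2.0 ** (score * scale_log2))
print(pick_launcher(64))  # run_mha_fwd_hdim64
```

&lt;p>Any other head_dim raises, much like the &lt;code>TORCH_CHECK&lt;/code> in the &lt;code>default&lt;/code> branch.&lt;/p>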
&lt;h1 id="wrapping-up">Wrapping up&lt;/h1>
&lt;blockquote>
&lt;p>9-1-1, there&amp;rsquo;s a psycho here. What do you mean reading is not a crime? What kind of lunatic reads a 29,000-word GPU programming blog about some three-year-old algorithm? I have to ask him whether he just skimmed it or skipped to the end? Oh, well, fair point.&lt;/p>
&lt;/blockquote>
&lt;p>If you&amp;rsquo;ve dug into even a few sections of this post, you&amp;rsquo;ve probably realized how unforgiving CuTe is. Calling it a &amp;ldquo;library&amp;rdquo; is misleading: it&amp;rsquo;s really a template engine that hides some additions and divisions behind friendly names. It&amp;rsquo;s a starter pack without a single instruction or label on it. When you expect it to handle some basic shapes or strides on your behalf, it spits in your face and makes you specify every single dim and stride yourself &amp;ndash; and when you inevitably get one wrong, it hands you a thousand-line template error that sends you straight back to the starter room.&lt;/p>
&lt;p>But you&amp;rsquo;ll also realize that&amp;rsquo;s exactly why it&amp;rsquo;s so powerful. You have control over every single knob and every single decision. The hardware is at your fingertips, and you can unlock as much or as little of it as you want. It&amp;rsquo;s an F1 steering wheel with a thousand switches, Federer&amp;rsquo;s 90-square-inch racket, or a $50,000 cinema camera rig. The layouts may spare you some hand-written PTX, but the hardware sits in plain view.&lt;/p>
&lt;p>Once you&amp;rsquo;ve written something as complex as FA2 in CuTe, you&amp;rsquo;ll understand swizzling, cache reuse, tiled copies, and MMAs better than anyone who slapped together some kernel in Triton on a Thursday afternoon. In 2026, there&amp;rsquo;s no such thing as &lt;em>just&lt;/em> a Triton engineer &amp;ndash; the GPU world is moving so rapidly that there&amp;rsquo;s no baseline stable enough for such a profession to exist. It will surely happen some day, but for now, the people who spend their days writing Triton spent years grinding through the stuff we just did together. So even though dozens of fancy tools are being written right now, getting your hands dirty is still a powerful asset. When the day comes that ClaudeGPT Gemini Pro Max is doing all of the work, there still won&amp;rsquo;t be a potion for understanding.&lt;sup id="fnref:10">&lt;a href="#fn:10" class="footnote-ref" role="doc-noteref">10&lt;/a>&lt;/sup>&lt;/p>
&lt;h2 id="whats-left">What&amp;rsquo;s Left&lt;/h2>
&lt;p>Even though we covered every concept in more detail than the teacher who hates you grades your homework, we never really touched the empirical backbone that drives kernel engineering. Everything we&amp;rsquo;ve done must be validated empirically &amp;ndash; print statements, shape checks, correctness tests, benchmarks. Even after your code compiles and produces correct results, you still have to profile achieved FLOPS, register pressure, occupancy, SMEM usage, and bank conflicts using tools like &lt;a class="link" href="https://developer.nvidia.com/nsight-compute" target="_blank" rel="noopener"
>Nsight Compute&lt;/a>. Only then will you see that you can improve your performance by 42% by optimizing GMEM access patterns or something. Profiling is a skill in itself, but knowing the concepts as deeply as we do now is critical for interpreting any of the charts. I leave this as an exercise for the reader &amp;ndash; maybe you&amp;rsquo;ll find bugs or problems in my explanations that I&amp;rsquo;m not even aware of. If so, please cite me in your revolutionary research paper so I get to share some of the credit.&lt;/p>
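&lt;p>One concrete form of that validation step: run the tiled computation and a slow, obviously correct reference on the same random inputs, then compare within a tolerance. Below is a minimal CPU sketch in pure Python, assuming a single head and tiny sizes. &lt;code>naive_attention&lt;/code> and &lt;code>tiled_attention&lt;/code> are hypothetical helpers, not part of the repo; the tiled one walks K/V in blocks with FA2-style online softmax and also returns the log-sum-exp, mirroring the &lt;code>{output, softmax_lse}&lt;/code> pair that &lt;code>mha_fwd&lt;/code> returns. In a real test you would compare the GPU kernel&amp;rsquo;s output instead, with a looser tolerance for half precision.&lt;/p>

```python
import math
import random


def naive_attention(q, k, v, scale):
    """Reference: full softmax(scale * Q K^T) V, one query row at a time."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        row_max = max(scores)  # subtract the max for numerical stability
        p = [math.exp(s - row_max) for s in scores]
        denom = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / denom
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, scale, block_n=2):
    """Stand-in for the kernel: online softmax over K/V blocks of size block_n."""
    out, lse = [], []
    for qi in q:
        row_max, row_sum = float("-inf"), 0.0  # running max and denominator
        acc = [0.0] * len(v[0])                # unnormalized output accumulator
        for start in range(0, len(k), block_n):
            kb = k[start:start + block_n]
            vb = v[start:start + block_n]
            s = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            new_max = max(row_max, max(s))
            correction = math.exp(row_max - new_max)  # rescale earlier blocks
            p = [math.exp(x - new_max) for x in s]
            row_sum = row_sum * correction + sum(p)
            acc = [a * correction + sum(pj * vj[c] for pj, vj in zip(p, vb))
                   for c, a in enumerate(acc)]
            row_max = new_max
        out.append([a / row_sum for a in acc])     # normalize once, at the end
        lse.append(row_max + math.log(row_sum))    # log-sum-exp of scaled scores
    return out, lse


def max_abs_diff(x, y):
    return max(abs(a - b) for ra, rb in zip(x, y) for a, b in zip(ra, rb))


def random_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
seq_q, seq_k, head_dim = 3, 5, 4
q = random_matrix(seq_q, head_dim)
k = random_matrix(seq_k, head_dim)
v = random_matrix(seq_k, head_dim)
scale = 1.0 / math.sqrt(head_dim)

ref = naive_attention(q, k, v, scale)
got, lse = tiled_attention(q, k, v, scale)
print(f"max abs diff vs reference: {max_abs_diff(ref, got):.1e}")
```

&lt;p>Varying &lt;code>block_n&lt;/code> (for example 1, or a value that does not divide the sequence length, like the default 2 with 5 keys) is a cheap way to exercise tile-boundary handling.&lt;/p>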
&lt;p>We also only covered a slightly out-of-date algorithm running on a somewhat out-of-date GPU. All the big AI labs are paying big bucks to optimize for the fancier H100s and B200s &amp;ndash; bajillion-dollar chips made specifically for your AI girlfriend to break your heart at 2am. We&amp;rsquo;re already on FA3 and even FA4. Newer hardware supports things like:&lt;/p>
&lt;ul>
&lt;li>FP8/FP4 MMA&lt;/li>
&lt;li>Warp specialization: producer warps use the new Tensor Memory Accelerator (TMA) to load tiles while consumer warps perform warp-group MMAs at lightning speed. The balance between data movement and compute looks completely different.&lt;/li>
&lt;li>New async paradigms with more specific barriers and fences.&lt;/li>
&lt;/ul>
&lt;p>And much more. Every hardware generation introduces more fancy features, and the labs are racing just to keep up. The GPU world is tiny, and the break-in period is so challenging that it takes months just to get an existing kernel loosely optimized for new-gen hardware. The hardware is becoming more and more mature, but you never know whether some new architecture looming on the horizon will shake everything up. Some companies are betting their existence on that happening.&lt;/p>
&lt;p>However, a fundamental understanding of FA2 is enough to make these new paradigms feel like just a thin new layer on top of all the complexity you&amp;rsquo;ve already learned. And that&amp;rsquo;s the silver lining: everything only becomes more digestible from here.&lt;/p>
&lt;h2 id="on-writing-this">On Writing This&lt;/h2>
&lt;p>Welcome to the true epilogue. This is by far the longest thing I have ever written in my life. I had never even written a technical blog about programming before, but hey, there&amp;rsquo;s a first time for everything. I sat down in late March hoping to learn a little about all this GPU kernel hubbub I&amp;rsquo;d been hearing about. Like a starter player in a level-1000 dungeon, I had no idea what kind of beasts were lurking just past the first room.&lt;/p>
&lt;p>This journey has been extremely tiring yet rewarding. When I got my kernel to compile for the first time, I felt like I finally understood it all after weeks of gruelling work. But the funny thing is that most of my real understanding came from writing this blog. I spent almost a hundred hours on it &amp;ndash; I drew diagrams in Canva only to delete them thinking I&amp;rsquo;d made them wrong, only to redraw them because I was actually right the first time. Understanding &lt;code>sVtNoSwizzle&lt;/code> took something like five passes over three weeks &amp;ndash; I would think I had it, hallucinate a plausible explanation, and eventually have to do a three-hour documentation deep-dive. Only then was I enlightened enough to realize the lovely docs don&amp;rsquo;t always explain crap, either. It&amp;rsquo;s been a true struggle. Whether you read only this last paragraph or the whole thing, the fact that you read any of it means a lot to me.&lt;/p>
&lt;p>Anyway, welcome to the club. Your tears and sweat are non-refundable.&lt;/p>
&lt;h1 id="resources">Resources&lt;/h1>
&lt;h2 id="this-blogs-code">This blog&amp;rsquo;s code&lt;/h2>
&lt;ul>
&lt;li>&lt;a class="link" href="https://github.com/cloudui/cuda-triton" target="_blank" rel="noopener"
>&lt;code>cloudui/cuda-triton&lt;/code>&lt;/a>&lt;/li>
&lt;/ul>
&lt;h2 id="production-fa2">Production FA2&lt;/h2>
&lt;ul>
&lt;li>&lt;a class="link" href="https://github.com/Dao-AILab/flash-attention" target="_blank" rel="noopener"
>&lt;code>Dao-AILab/flash-attention&lt;/code>&lt;/a>&lt;/li>
&lt;li>&lt;a class="link" href="https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src" target="_blank" rel="noopener"
>&lt;code>csrc/flash_attn/src/&lt;/code>&lt;/a> &amp;ndash; C++ source. Most relevant files: &lt;code>flash_fwd_kernel.h&lt;/code>, &lt;code>kernel_traits.h&lt;/code>.&lt;/li>
&lt;li>&lt;a class="link" href="https://arxiv.org/pdf/2307.08691" target="_blank" rel="noopener"
>FlashAttention-2 paper (Dao, 2023)&lt;/a>&lt;/li>
&lt;li>&lt;a class="link" href="https://tridao.me/blog/2024/flash3/" target="_blank" rel="noopener"
>FlashAttention-3&lt;/a>&lt;/li>
&lt;/ul>
&lt;h2 id="cutlass--cute-official">CUTLASS / CuTe official&lt;/h2>
&lt;ul>
&lt;li>&lt;a class="link" href="https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html" target="_blank" rel="noopener"
>CuTe/CUTLASS Docs&lt;/a>&lt;/li>
&lt;li>&lt;a class="link" href="https://github.com/NVIDIA/cutlass" target="_blank" rel="noopener"
>NVIDIA/cutlass repo&lt;/a>. The source files are in &lt;code>include/&lt;/code>.&lt;/li>
&lt;/ul>
&lt;h2 id="blogs">Blogs&lt;/h2>
&lt;ul>
&lt;li>&lt;a class="link" href="https://leimao.github.io/blog/CuTe-Swizzle/" target="_blank" rel="noopener"
>Lei Mao&amp;rsquo;s Blog&lt;/a>. Browse around; he has great explanations of CuTe.&lt;/li>
&lt;li>&lt;a class="link" href="https://developer.nvidia.com/blog" target="_blank" rel="noopener"
>NVIDIA Blogs&lt;/a>&lt;/li>
&lt;li>&lt;a class="link" href="https://lubits.ch/flash/" target="_blank" rel="noopener"
>Sonny&amp;rsquo;s Blog &amp;ndash; FA2 from Scratch&lt;/a>. Literally raw inline PTX for those with time.&lt;/li>
&lt;li>&lt;a class="link" href="https://ita9naiwa.github.io/mlsys/2023/11/16/attention-cuda.html" target="_blank" rel="noopener"
>Hyunsung Lee&amp;rsquo;s Blog &amp;ndash; CUDA + Attention&lt;/a>&lt;/li>
&lt;/ul>
&lt;h2 id="nvidia-reference-docs">NVIDIA reference docs&lt;/h2>
&lt;ul>
&lt;li>&lt;a class="link" href="https://docs.nvidia.com/cuda/cuda-programming-guide/index.html" target="_blank" rel="noopener"
>CUDA C++ Programming Guide&lt;/a>. Literally every CUDA concept from the Big Boss himself.&lt;/li>
&lt;li>&lt;a class="link" href="https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async" target="_blank" rel="noopener"
>PTX ISA Dictionary&lt;/a>. For those who love torture.&lt;/li>
&lt;/ul>
&lt;h1 id="appendix">Appendix&lt;/h1>
&lt;div class="footnotes" role="doc-endnotes">
&lt;hr>
&lt;ol>
&lt;li id="fn:1">
&lt;p>Tri Dao original FA2 paper: &lt;a class="link" href="https://arxiv.org/pdf/2307.08691" target="_blank" rel="noopener"
>https://arxiv.org/pdf/2307.08691&lt;/a>&amp;#160;&lt;a href="#fnref:1" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:2">
&lt;p>OK, sorry. CUTLASS is a cool sword and also stands for CUDA Templates for Linear Algebra Subroutines and Solvers&amp;#160;&lt;a href="#fnref:2" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:3">
&lt;p>PTX (Parallel Thread Execution) and SASS (Streaming Assembler) are two different levels of NVIDIA&amp;rsquo;s GPU instructions. PTX is a set of virtual instructions that map to hardware-level SASS instructions based on the underlying architecture.&amp;#160;&lt;a href="#fnref:3" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&amp;#160;&lt;a href="#fnref1:3" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&amp;#160;&lt;a href="#fnref2:3" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:4">
&lt;p>A word is 4 bytes (32 bits). The term is somewhat ambiguous depending on architecture or context, e.g. a word on an n-bit CPU means n bits. But in CUDA, it almost always means 32 bits. Alternatives include scalars, floats, or bank widths, but we will stick to &amp;ldquo;word&amp;rdquo; when discussing bank conflicts.&amp;#160;&lt;a href="#fnref:4" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:5">
&lt;p>Four-transaction 512-byte load explanation: &lt;a class="link" href="https://forums.developer.nvidia.com/t/128-bit-access-bank-conflict/287039/5" target="_blank" rel="noopener"
>https://forums.developer.nvidia.com/t/128-bit-access-bank-conflict/287039/5&lt;/a>&amp;#160;&lt;a href="#fnref:5" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:6">
&lt;p>CuTe &lt;code>partition_fragment()&lt;/code> source: &lt;a class="link" href="https://github.com/NVIDIA/cutlass/blob/e406c186f510a15091cce01f782020ceb7ba8eb5/include/cute/atom/mma_atom.hpp#L508" target="_blank" rel="noopener"
>https://github.com/NVIDIA/cutlass/blob/e406c186f510a15091cce01f782020ceb7ba8eb5/include/cute/atom/mma_atom.hpp#L508&lt;/a>&amp;#160;&lt;a href="#fnref:6" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&amp;#160;&lt;a href="#fnref1:6" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:7">
&lt;p>CuTe &lt;code>make_fragment_like()&lt;/code> source: &lt;a class="link" href="https://github.com/NVIDIA/cutlass/blob/e406c186f510a15091cce01f782020ceb7ba8eb5/include/cute/tensor_impl.hpp#L463" target="_blank" rel="noopener"
>https://github.com/NVIDIA/cutlass/blob/e406c186f510a15091cce01f782020ceb7ba8eb5/include/cute/tensor_impl.hpp#L463&lt;/a>&amp;#160;&lt;a href="#fnref:7" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:8">
&lt;p>&lt;a class="link" href="https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/" target="_blank" rel="noopener"
>https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/&lt;/a>&amp;#160;&lt;a href="#fnref:8" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:9">
&lt;p>Oxford shuffle lecture notes, p.6 for XOR warp shuffle: &lt;a class="link" href="https://people.maths.ox.ac.uk/gilesm/cuda/lecs/lec4.pdf" target="_blank" rel="noopener"
>https://people.maths.ox.ac.uk/gilesm/cuda/lecs/lec4.pdf&lt;/a>&amp;#160;&lt;a href="#fnref:9" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;li id="fn:10">
&lt;p>Honestly, maybe they&amp;rsquo;ll figure out how to inject our brains directly with learning modules and everyone becomes an expert tomorrow.&amp;#160;&lt;a href="#fnref:10" class="footnote-backref" role="doc-backlink">&amp;#x21a9;&amp;#xfe0e;&lt;/a>&lt;/p>
&lt;/li>
&lt;/ol>
&lt;/div></description></item></channel></rss>