<div align='center'>
<img src='https://github.com/user-attachments/assets/b2578723-b7a7-4d8f-bcd1-5008947b808a' >
</div>
<div align='center'>
<img src=https://cdn.rawgit.com/sindresorhus/awesome/d7305f38d29fed78fa85652e3a63e154dd8e8829/media/badge.svg >
<img src=https://img.shields.io/badge/Language-CUDA-brightgreen.svg >
<img src=https://img.shields.io/github/watchers/DefTruth/cuda-learn-note?color=9cc >
<img src=https://img.shields.io/github/forks/DefTruth/cuda-learn-note.svg?style=social >
<img src=https://img.shields.io/github/stars/DefTruth/cuda-learn-note.svg?style=social >
<img src=https://img.shields.io/badge/Release-v2.6-brightgreen.svg >
<img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
</div>
<div id="contents"></div>
📚 Modern CUDA Learn Notes with PyTorch for Beginners: it covers Tensor/CUDA Cores, TF32/F16/BF16/F8, 📖150+ CUDA Kernels🔥🔥 (Easy -> Hard++) with PyTorch bindings, 📖100+ LLM/VLM/CV/CUDA/CuTe🔥 blogs, 📖toy-hgemm⚡️⚡️ which achieves 98%~100% of cuBLAS performance, and 📖flash-attention-mma⚡️⚡️ using Tensor Cores with pure MMA PTX. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉
<div id="hgemm-sgemm"></div>
<div align='center'>
<img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="170px" width="270px">
<img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="170px" width="270px">
<img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="170px" width="270px">
</div>
Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores algorithm, the HGEMM (WMMA/MMA/CuTe) in this repo (blue 🔵) achieves 98%~100% of its (orange 🟠) performance. Please check the toy-hgemm library⚡️⚡️ or the hgemm-tensorcores-mma⚡️⚡️ repo for more details.
|CUDA Cores|Sliced K (Loop over K)|Tile Block (BMxBK)|Tile Thread (t 8x8)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|WMMA (m16n16k16)|MMA (m16n8k16)|Pack LDST (128 bits)|SMEM Padding|
|✔️|✔️|✔️|✔️|
|Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)|
|✔️|✔️|✔️|✔️|
|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)|
|✔️|✔️|✔️|✔️|
|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)|SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|
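For orientation, here is a minimal sketch of the two most basic ingredients above: a sliced-K loop driving a WMMA m16n16k16 Tensor Core MMA, with one warp producing one 16x16 tile of C. It is an illustrative toy (kernel name and launch shape are mine, not taken from ./kernels/hgemm); the real kernels add block/warp tiling, packed 128-bit LDST, cp.async multi-staging and swizzling on top of this.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Illustrative sketch (sm_70+): C[M,N] = A[M,K] * B[K,N], all row-major fp16,
// M/N/K assumed to be multiples of 16. Each warp owns one 16x16 tile of C and
// loops over K in steps of 16 (sliced K). Launch with (M/16)*(N/16) warps in total,
// e.g. blockDim.x = 128 and gridDim.x = (M/16)*(N/16)/4.
__global__ void hgemm_wmma_sliced_k_sketch(const half* A, const half* B, half* C,
                                           int M, int N, int K) {
  const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  const int warps_per_row = N / 16;
  const int tile_m = (warp_id / warps_per_row) * 16;  // row offset of this warp's C tile
  const int tile_n = (warp_id % warps_per_row) * 16;  // col offset of this warp's C tile
  if (tile_m >= M) return;

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));

  for (int k = 0; k < K; k += 16) {                      // sliced K: loop over K
    wmma::load_matrix_sync(a_frag, A + tile_m * K + k, K);
    wmma::load_matrix_sync(b_frag, B + k * N + tile_n, N);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);      // Tensor Core MMA
  }
  wmma::store_matrix_sync(C + tile_m * N + tile_n, c_frag, N, wmma::mem_row_major);
}
```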
I have also implemented FlashAttention-2 using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, Fully Shared QKV SMEM, Prefetch Q s2r, Collective Store, etc. Please refer to flash-attention-mma⚡️⚡️ for more details.
|Tensor Cores|Loop over Seqlen/Headdim|Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|Split KV/Q|
|✔️|✔️|✔️|✔️|
|Shared QKV/KV SMEM|Prefetch Q s2r|Prefetch K/V g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|❔|
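The MMA atom named above is the m16n8k16 Tensor Core instruction. As a minimal illustration (a generic helper sketch, not code copied from this repo), issuing one such MMA with f16 accumulators via inline PTX looks like this, assuming the per-thread register fragments RA/RB/RC have already been filled, e.g. via ldmatrix from shared memory:

```cuda
#include <cstdint>

// One warp-wide mma.sync.aligned.m16n8k16 with f16 inputs and f16 accumulators
// (requires sm_80+). Per thread, RA holds 8 halves of the A tile (4 x b32 regs),
// RB holds 4 halves of B (2 regs), and RC/RD hold 4 halves of C/D (2 regs each).
__device__ __forceinline__ void mma_m16n8k16_f16(uint32_t* RD, const uint32_t* RA,
                                                 const uint32_t* RB, const uint32_t* RC) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
      : "=r"(RD[0]), "=r"(RD[1])
      : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]),
        "r"(RB[0]), "r"(RB[1]), "r"(RC[0]), "r"(RC[1]));
}
```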
Currently, for small-scale attention (B<=4, H<=48, SeqLen<=8192) it can run faster than the official FA2 on some devices. For large-scale attention, however, a performance gap remains; performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
```bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
torch(unfused): ['-0.00514603 ', '0.05783081 ', '-0.00026727 '], time:20.999861ms, TFLOPS:6.67 (+0.00%)
mma(split-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.120730ms, TFLOPS:27.36 (+310.10%)
mma(split-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:5.004287ms, TFLOPS:28.00 (+2.33%)
mma(split-q+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.462291ms, TFLOPS:40.47 (+44.54%)
mma(split-q+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:3.658915ms, TFLOPS:38.30
mma(split-q+share-qkv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.551699ms, TFLOPS:54.91 (+35.69%)
mma(split-q+share-qkv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.532172ms, TFLOPS:55.34 (+0.77%)
mma(split-q+share-kv+stage1): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.776575ms, TFLOPS:50.46
mma(split-q+share-kv+stage2): ['-0.00511169 ', '0.05795288 ', '-0.00029612 '], time:2.596927ms, TFLOPS:53.96
(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
----------------------------------------------------------------------------------------------------------------------------------
```
Both the Split KV and Split Q implementations are provided in flash-attention-mma⚡️⚡️ for performance comparison. The Split KV policy, which splits all of Q, K and V across MMAs (warps), is slower than the Split Q policy, which splits only Q across MMAs (warps) while keeping K and V accessible to all MMAs (warps).
- 📚 Split KV (Basic, FlashAttention-1)
<div id="mma-split-kv"></div>
```cuda
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// case: The layout of 8 MMA(2x4) [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64:
// |   [64,64] |    warp_KV 0    |    warp_KV 1    |    warp_KV 2    |    warp_KV 3    |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_kv_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q (Faster, FlashAttention-2)
<div id="mma-split-q"></div>
```cuda
// Split Q across MMA(Warps) and keep access KV for all MMA(Warps),
// in order to reduce the comm between warps via smem and warp shuffle.
// case: MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// |   64x64   |      warp_KV 0       |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q + Shared KV SMEM (1/2 SRAM vs FA2)
<div id="mma-share-kv"></div>
```cuda
// K and V share the same shared memory buffer, which improves block occupancy.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q + Fully Shared QKV SMEM (1/4 SRAM vs FA2)
<div id="mma-share-qkv"></div>
```cuda
// Q, K and V fully share the same shared memory and Q is prefetched (s2r), which
// improves block occupancy and reduces Q SMEM IO-access.
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half* O, ...);
```
- 📚 Split Q + QK Fine-grained Tiling (O(16xd) SRAM vs FA2 O(4xBrxd) SRAM, Headdim -> 1024)
<div id="mma-tiling-qk"></div>
```cuda
// Fine-grained tiling (at the MMA level) for Q/K: it needs only a constant SRAM size of
// 64*kMmaAtomK for Q/K and O(kMmaAtomK*d) SRAM for V, so the overall SRAM complexity is
// O(kMmaAtomK*d). Thus, we can extend D (headdim) up to 1024. Performance tuning is
// ongoing, stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
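As a rough back-of-the-envelope for why this enables headdim up to 1024 (assuming Br = Bc = 64, kMmaAtomK = 16 and fp16 elements; these concrete numbers are my assumptions, not values quoted from the kernel source), the per-block SMEM budget is roughly

$$
\underbrace{2\cdot(64\times 16)\cdot 2\,\mathrm{B}}_{Q,K\ \text{tiles}}
+ \underbrace{(16\times 1024)\cdot 2\,\mathrm{B}}_{V\ \text{tile},\ d=1024}
= 4\,\mathrm{KB} + 32\,\mathrm{KB} = 36\,\mathrm{KB},
$$

which fits in on-chip shared memory, whereas FA2-style O(4xBrxd) buffering at the same sizes would need 4 x 64 x 1024 x 2 B = 512 KB.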
©️Citations🎉🎉
```bibtex
@misc{CUDA-Learn-Notes@2024,
  title={CUDA-Learn-Notes: A Modern CUDA Learn Notes with PyTorch for Beginners},
  url={https://github.com/DefTruth/CUDA-Learn-Notes},
  note={Open-source software available at https://github.com/DefTruth/CUDA-Learn-Notes},
  author={DefTruth etc},
  year={2024}
}
```
📖 150+ CUDA Kernels 🔥🔥 (Easy -> Hard++) (©️back👆🏻)
<div id="cuda-kernel"></div>
The kernels listed here will guide you through a step-by-step progression, ranging from easy to very challenging topics. The workflow looks like: custom CUDA kernel impl -> PyTorch Python bindings -> run tests. 👉TIPS: * = Tensor Cores (WMMA, MMA, CuTe), otherwise CUDA Cores; / = not supported; ✔️ = supported; ❔ = TODO.
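To make that workflow concrete, here is a minimal self-contained sketch: a toy element-wise CUDA kernel plus its PyTorch binding (the file name, function names and the ReLU example itself are illustrative choices, not one of the kernels listed below), which can be JIT-compiled with torch.utils.cpp_extension.load and then exercised from a Python test.

```cuda
// relu_example.cu -- illustrative workflow sketch: CUDA kernel -> PyTorch binding -> test.
// Python side: m = torch.utils.cpp_extension.load(name="relu_example", sources=["relu_example.cu"])
//              torch.testing.assert_close(m.relu_f32(x), torch.relu(x))
#include <torch/extension.h>

__global__ void relu_f32_kernel(const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = fmaxf(x[i], 0.0f);  // element-wise ReLU
}

torch::Tensor relu_f32(torch::Tensor x) {
  TORCH_CHECK(x.is_cuda() && x.scalar_type() == torch::kFloat32, "expect fp32 CUDA tensor");
  auto xc = x.contiguous();
  auto y = torch::empty_like(xc);
  const int n = static_cast<int>(xc.numel());
  const int threads = 256, blocks = (n + threads - 1) / threads;
  relu_f32_kernel<<<blocks, threads>>>(xc.data_ptr<float>(), y.data_ptr<float>(), n);
  return y;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("relu_f32", &relu_f32, "element-wise ReLU (CUDA)");  // exposed to Python
}
```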
📚 Easy and 📚 Medium sections cover fundamental operations such as element-wise, mat_trans, warp/block reduce, online-softmax, nms, layer-norm, rms-norm, dot-prod etc. 📚 Hard and 📚 Hard++ sections delve deeper into advanced topics, primarily focusing on operations like sgemv, sgemm, hgemv, hgemm and flash-attention. These sections also provide numerous kernels implemented using Tensor Cores with pure MMA PTX instructions.
📚 Easy ⭐️ & Medium ⭐️⭐️ (©️back👆🏻)
<div id="cuda-kernel-easy-medium"></div>
📚 Hard ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ (©️back👆🏻)
<div id="cuda-kernel-hard"></div>
<!---
|📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
|:---|:---|:---|:---|:---|
| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
| ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/flash_attn_mma_split_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/flash_attn_mma_split_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [sgemm_naive_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️|
| ✔️ [sgemm_sliced_k_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k...bcf](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k...dbuf](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k16...async](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...stages*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_naive_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
| ✔️ [hgemm_sliced_k_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k_f16x4](./kernels/hgemm/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k_f16x4_pack](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k_f16x8_pack](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k...dbuf](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8/16x8...k16/32...dbuf](./kernels/hgemm/naive/hgemm_async.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8/16x8...k16/32...async](./kernels/hgemm/naive/hgemm_async.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...naive*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...mma4x2*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...mma4x4*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m32n8k16....dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...stages*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...swizzle*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_stages{swizzle}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_cublas*](./kernels/hgemm/cublas/hgemm_cublas.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
| ✔️ [sgemv_k32_f32](./kernels/sgemv/sgemv.cu)|f32|f32|[link](./kernels/sgemv/)|⭐️⭐️⭐️|
| ✔️ [sgemv_k128_f32x4](./kernels/sgemv/sgemv.cu)|f32|f32|[link](./kernels/sgemv/)|⭐️⭐️⭐️|
| ✔️ [sgemv_k16_f32](./kernels/sgemv/sgemv.cu)|f32|f32|[link](./kernels/sgemv/)|⭐️⭐️⭐️|
| ✔️ [hgemv_k32_f16](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️|
| ✔️ [hgemv_k128_f16x4](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️|
| ✔️ [hgemv_k16_f16](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️|
| ✔️ [elementwise_f32](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f32x4](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f16](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f16x2](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f16x8](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f16x8_pack](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️⭐️|
| ✔️ [histogram_i32](./kernels/histogram/histogram.cu)|i32|/|[link](./kernels/histogram/)|⭐️|
| ✔️ [histogram_i32x4](./kernels/histogram/histogram.cu)|i32|/|[link](./kernels/histogram/)|⭐️|
| ✔️ [sigmoid_f32](./kernels/sigmoid/sigmoid.cu)|f32|/|[link](./kernels/sigmoid/)|⭐️|
| ✔️ [sigmoid_f32x4](./kernels/sigmoid/sigmoid.cu)|f32|/|[link](./kernels/sigmoid/)|⭐️|
| ✔️ [sigmoid_f16](./kernels/sigmoid/sigmoid.cu)|f16|/|[link](./kernels/sigmoid/)|⭐️|
| ✔️ [sigmoid_f16x2](./kernels/sigmoid/sigmoid.cu)|f16|/|[link](./kernels/sigmoid/)|⭐️|
| ✔️ [sigmoid_f16x8](./kernels/sigmoid/sigmoid.cu)|f16|/|[link](./kernels/sigmoid/)|⭐️|
| ✔️ [sigmoid_f16x8_pack](./kernels/sigmoid/sigmoid.cu)|f16|/|[link](./kernels/sigmoid/)|⭐️⭐️|
| ✔️ [relu_f32](./kernels/relu/relu.cu)|f32|/|[link](./kernels/relu/)|⭐️|
| ✔️ [relu_f32x4](./kernels/relu/relu.cu)|f32|/|[link](./kernels/relu/)|⭐️|
| ✔️ [relu_f16](./kernels/relu/relu.cu)|f16|/|[link](./kernels/relu/)|⭐️|
| ✔️ [relu_f16x2](./kernels/relu/relu.cu)|f16|/|[link](./kernels/relu/)|⭐️|
| ✔️ [relu_f16x8](./kernels/relu/relu.cu)|f16|/|[link](./kernels/relu/)|⭐️|
| ✔️ [relu_f16x8_pack](./kernels/relu/relu.cu)|f16|/|[link](./kernels/relu/)|⭐️⭐️|
| ✔️ [gelu_f32](./kernels/gelu/gelu.cu)|f32|/|[link](./kernels/gelu/)|⭐️|
| ✔️ [gelu_f32x4](./kernels/gelu/gelu.cu)|f32|/|[link](./kernels/gelu/)|⭐️|
| ✔️ [gelu_f16](./kernels/gelu/gelu.cu)|f16|/|[link](./kernels/gelu/)|⭐️|
| ✔️ [gelu_f16x2](./kernels/gelu/gelu.cu)|f16|/|[link](./kernels/gelu/)|⭐️|
| ✔️ [gelu_f16x8](./kernels/gelu/gelu.cu)|f16|/|[link](./kernels/gelu/)|⭐️|
| ✔️ [gelu_f16x8_pack](./kernels/gelu/gelu.cu)|f16|/|[link](./kernels/gelu/)|⭐️⭐️|
| ✔️ [swish_f32](./kernels/swish/swish.cu)|f32|/|[link](./kernels/swish/)|⭐️|
| ✔️ [swish_f32x4](./kernels/swish/swish.cu)|f32|/|[link](./kernels/swish/)|⭐️|
| ✔️ [swish_f16](./kernels/swish/swish.cu)|f16|/|[link](./kernels/swish/)|⭐️|
| ✔️ [swish_f16x2](./kernels/swish/swish.cu)|f16|/|[link](./kernels/swish/)|⭐️|
| ✔️ [swish_f16x8](./kernels/swish/swish.cu)|f16|/|[link](./kernels/swish/)|⭐️|
| ✔️ [swish_f16x8_pack](./kernels/swish/swish.cu)|f16|/|[link](./kernels/swish/)|⭐️⭐️|
| ✔️ [embedding_f32](./kernels/embedding/embedding.cu)|f32|/|[link](./kernels/embedding/)|⭐️|
| ✔️ [embedding_f32x4](./kernels/embedding/embedding.cu)|f32|/|[link](./kernels/embedding/)|⭐️|
| ✔️ [embedding_f32x4_pack](./kernels/embedding/embedding.cu)|f32|/|[link](./kernels/embedding/)|⭐️|
| ✔️ [embedding_f16](./kernels/embedding/embedding.cu)|f16|/|[link](./kernels/embedding/)|⭐️|
| ✔️ [embedding_f16x2](./kernels/embedding/embedding.cu)|f16|/|[link](./kernels/embedding/)|⭐️|
| ✔️ [embedding_f16x8](./kernels/embedding/embedding.cu)|f16|/|[link](./kernels/embedding/)|⭐️|
| ✔️ [embedding_f16x8_pack](./kernels/embedding/embedding.cu)|f16|/|[link](./kernels/embedding/)|⭐️⭐️|
| ✔️ [mat_trans_f32_col2row{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️|
| ✔️ [mat_trans_f32_row2col{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️|
| ✔️ [mat_trans_f32_diagonal2d](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_col2row{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_row2col{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [warp_reduce_[all]](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [dot_product_f32](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f32x4](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16x2_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16x8_pack_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [softmax_f32(fence)](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [softmax_f32x4(fence)](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [softmax_f32](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [softmax_f32x4](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32x4](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16_f32](./kernels/softmax/softmax.cu)|f16|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16x2_f32](./kernels/softmax/softmax.cu)|f16|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16x8_pack_f32](./kernels/softmax/softmax.cu)|f16|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [online_safe_softmax_f32](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [online_safe_softmax_f32x4_pack](./kernels/softmax/softmax.cu)|f32|f32|[link](./kernels/softmax/)|⭐️⭐️|
| ✔️ [rope_f32](./kernels/rope/rope.cu)|f32|f32|[link](./kernels/rope/)|⭐️⭐️|
| ✔️ [rope_f32x4_pack](./kernels/rope/rope.cu)|f32|f32|[link](./kernels/rope/)|⭐️⭐️|
| ✔️ [layer_norm_f32](./kernels/layer-norm/layer_norm.cu)|f32|f32|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f32x4](./kernels/layer-norm/layer_norm.cu)|f32|f32|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f16](./kernels/layer-norm/layer_norm.cu)|f16|f16|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x2_f16](./kernels/layer-norm/layer_norm.cu)|f16|f16|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_f16](./kernels/layer-norm/layer_norm.cu)|f16|f16|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f16](./kernels/layer-norm/layer_norm.cu)|f16|f16|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f32](./kernels/layer-norm/layer_norm.cu)|f16|f32|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f32](./kernels/layer-norm/layer_norm.cu)|f16|f32|[link](./kernels/layer-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32](./kernels/rms-norm/rms_norm.cu)|f32|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32x4](./kernels/rms-norm/rms_norm.cu)|f32|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f16](./kernels/rms-norm/rms_norm.cu)|f16|f16|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x2_f16](./kernels/rms-norm/rms_norm.cu)|f16|f16|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_f16](./kernels/rms-norm/rms_norm.cu)|f16|f16|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_pack_f16](./kernels/rms-norm/rms_norm.cu)|f16|f16|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_pack_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️|
| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️|
--->
📖 Blog Contents
<div id="my-blogs-part-1"></div>
📚 LLM|Multimodal|Diffusion|Inference Optimization (written by me) (©️back👆🏻)
📚 CV Inference & Deployment|C++|Algorithms|Tech Notes (written by me) (©️back👆🏻)
<div id="my-blogs-part-2"></div>
📚 CUTLASS|CuTe|NCCL|CUDA|Recommended Articles (by other authors) (©️back👆🏻)
<div id="other-blogs"></div>
💡Note: This subsection collects articles that I particularly like. PRs recommending more excellent articles are welcome!
<div id="License"></div>
GNU General Public License v3.0
<div id="Contribute"></div>
How to contribute? Please check 🌤🌤CONTRIBUTE🎉🎉.
<div align='center'>
<a href="https://star-history.com/#DefTruth/CUDA-Learn-Notes&Date">
<picture align='center'>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=DefTruth/CUDA-Learn-Notes&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=DefTruth/CUDA-Learn-Notes&type=Date" />
<img width=450 height=300 alt="Star History Chart" src="https://api.star-history.com/svg?repos=DefTruth/CUDA-Learn-Notes&type=Date" />
</picture>
</a>
</div>