Awesome
<img src='./docs/wmmae.svg' width=200>WMMA API Extension
This extension provides features for
- mapping between memory and fragment (primitive functions)
- operations for vectors
- loading a vector as a fragment
- storing a fragment as a vector
- C++ interface for `mma` instructions [detail]
- Error Correction (TCEC) for SGEMM emulation [detail]
- arithmetic operators for fragments (`+, -, *, /, fma`) [detail]
- utils [detail]
- etc
without using extra shared memory.
[!IMPORTANT] Please specify a virtual architecture that is appropriate for the GPU the program will actually run on. For instance, a program compiled with `-arch=sm_70` will not work correctly on Ampere GPUs.
Requirements
- CUDA (10.2 or later)
- C++ (17 or later)
Supported architectures / fragment
- sm_70: ((16, 16, 16), fp16/fp32)
- sm_75: ((16, 16, 16), fp16/fp32)
- sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (the `wgmma` instruction is not supported yet)
Functions
Primitive functions
foreach
This function calculates the mapping between memory elements and fragment elements.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach<decltype(frag_b)>(
[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
const auto m = mem_index % 16;
const auto n = mem_index / 16;
for (unsigned i = 0; i < fragment_index_count; i++)
frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
});
foreach_ij
This function calculates the mapping between a matrix element position (i, j) and fragment elements.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach_ij<decltype(frag_b)>(
[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
for (unsigned f = 0; f < fragment_index_count; f++)
frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
});
foreach_v
For matrix A/B
This function calculates the mapping between the elements of a given vector and fragment elements.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_b)>(
[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
for (unsigned i = 0; i < fragment_index_count; i++)
frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
});
// is equivalent to `load_vector`
For accumulator
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
for (unsigned i = 0; i < fragment_index_count; i++)
vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
});
// is equivalent to `store_vector`
map
This function returns the mapping between a matrix element (i, j) and the corresponding fragment elements (tid, fid).
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
for (unsigned k = 0; k < list_size; k++) {
if ((threadIdx.x & 0x1f) == tid_list[k]) {
frag_b.x[fid_list[k]] = 3.0f;
}
}
Functions for vector
Sample
#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>
__global__ void kernel() {
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ float vec16[16];
mtk::wmma::load_vector(frag_a, vec16);
mtk::wmma::load_vector(frag_b, vec16);
nvcuda::wmma::fill_fragment(frag_c, 0.0f);
nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);
mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}
Other functions
make_identity_matrix / add_eye
- Arguments
- dst_fragment : Destination fragment (`accumulator`)
- alpha : diagonal element
fill_zero
- Argument
- dst_fragment : Destination fragment
Debugging functions
print_fragment
This function outputs the elements of a fragment.
- Arguments
- frag : Target fragment
- name : name printed for the fragment (`char*`, optional)
Publication
@inproceedings{ootomo_wmmae_2023,
author = {Ootomo, Hiroyuki and Yokota, Rio},
title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
year = {2023},
series = {HPC Asia '23}
}
LICENSE
MIT