CatConv2d
Concat+Conv2d fused all-in-one CUDA kernel extension for PyTorch
- Optimized for Volta, 10% faster than torch concat+conv (batch size = 1)
x = CatConv2d(x_list)
# effectively equal to
x = torch.cat(x_list, 1)
x = conv.forward(x)
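The equivalence above can be checked against plain PyTorch. The following is a minimal sketch, not the extension's code: it builds the reference path (`torch.cat` followed by a convolution) and shows that the same output can be obtained without the concatenated tensor, by convolving each input with its slice of the weight and summing. The shapes, channel counts, and variable names are assumptions chosen for illustration, and this is only one way to see why a fused kernel can skip the concat; the CUDA kernel may be organized differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical inputs: three NCHW feature maps, batch size 1, same H and W.
x_list = [torch.randn(1, c, 32, 32) for c in (8, 16, 4)]
total_channels = sum(t.shape[1] for t in x_list)

# Ordinary 3x3 convolution with zero padding of 1 (the baseline being replaced).
conv = nn.Conv2d(total_channels, 32, kernel_size=3, padding=1, bias=False)

with torch.no_grad():
    # Reference path: materialize the concatenation, then convolve.
    reference = conv(torch.cat(x_list, dim=1))

    # Same result without the concatenated tensor: convolve each input with
    # the matching input-channel slice of the weight and accumulate.
    split_sum = torch.zeros_like(reference)
    offset = 0
    for t in x_list:
        c = t.shape[1]
        w = conv.weight[:, offset:offset + c]
        split_sum += F.conv2d(t, w, padding=1)
        offset += c

max_abs_diff = (reference - split_sum).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two results agree up to float32 rounding, since convolution is linear in the input channels.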
Features
- Conv1x1 (GEMM) CUDA kernel
- Conv3x3 Winograd CUDA kernel, fused with input/weight transform
- Takes a Python list of Tensors as input
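As background for the "Conv1x1 (GEMM)" item: a 1x1 convolution on an NCHW tensor is a matrix multiply between the `[C_out, C_in]` weight and the input viewed as `[C_in, H*W]`. The sketch below checks this in plain PyTorch. The sizes are hypothetical, and it illustrates the math only, not the CUDA kernel itself.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration.
x = torch.randn(1, 12, 16, 16)       # NCHW input
weight = torch.randn(24, 12, 1, 1)   # 24 output channels, 1x1 kernel

# 1x1 convolution as PyTorch computes it.
conv_out = F.conv2d(x, weight)

# The same computation as a GEMM: [C_out, C_in] @ [C_in, H*W] per batch item.
n, c_in, h, w = x.shape
c_out = weight.shape[0]
gemm_out = (weight.view(c_out, c_in) @ x.view(n, c_in, h * w)).view(n, c_out, h, w)

gemm_max_diff = (conv_out - gemm_out).abs().max().item()
print(f"max abs difference: {gemm_max_diff:.2e}")
```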
Limitations
- Forward only
- Maximum 16 Tensors for the input list
- NCHW layout
- Channel size >= 4 for each input Tensor in the list
- Hardcoded zero padding (padding=1 for Conv3x3)
- Optimized for small batch size
- Optimized on Volta
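Some of the limitations above can be checked before calling the kernel. The helper below is a hypothetical sketch and not part of this repository: `check_catconv_inputs` and its constants are made-up names. It works on plain shape tuples such as `tuple(t.shape)`, and it only covers the list-length, layout-rank, and channel-count limits, plus the usual requirement that tensors concatenated along dim 1 agree in N, H, and W.

```python
MAX_INPUTS = 16    # "Maximum 16 Tensors for the input list"
MIN_CHANNELS = 4   # "Channel size >= 4 for each input Tensor in the list"


def check_catconv_inputs(shapes):
    """Return a list of problems for a list of NCHW shape tuples (hypothetical helper)."""
    if not shapes:
        return ["input list is empty"]

    problems = []
    if len(shapes) > MAX_INPUTS:
        problems.append(f"{len(shapes)} inputs given, at most {MAX_INPUTS} supported")

    four_d = []
    for i, shape in enumerate(shapes):
        if len(shape) != 4:
            problems.append(f"input {i}: expected 4 dimensions (NCHW), got {len(shape)}")
            continue
        four_d.append((i, shape))
        if shape[1] < MIN_CHANNELS:
            problems.append(f"input {i}: {shape[1]} channels, need at least {MIN_CHANNELS}")

    # Concatenation along dim 1 requires N, H, and W to match across inputs.
    if four_d:
        _, first = four_d[0]
        for i, shape in four_d[1:]:
            if (shape[0], shape[2], shape[3]) != (first[0], first[2], first[3]):
                problems.append(f"input {i}: N/H/W {shape} do not match {first}")

    return problems


print(check_catconv_inputs([(1, 8, 32, 32), (1, 2, 32, 32)]))
```

With real tensors this would be called as `check_catconv_inputs([tuple(t.shape) for t in x_list])`.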
TODO List
- Backward path
- SASS level optimization
- TensorCore 16bit version
Requirement
- CUDA 10.2
Usage
python setup.py install
python test.py cuda
Result on Titan V
- Test target: HarDNet85 with 512x512 input size
Torch Conv: 6.489803947205326 ms
CatConv: 5.862883790766858 ms, latency reduction: 9.660078509897597 %
- Applied to FCHarDNet, it gives a significant speedup (70 fps -> 99 fps)
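For reference, the reported latency reduction follows directly from the two timings above, and the FCHarDNet figures can be turned into a throughput ratio the same way. This is a small arithmetic sketch with made-up variable names; the numbers are copied from the results above.

```python
# Timings reported above (HarDNet85, 512x512 input, Titan V).
torch_conv_ms = 6.489803947205326
catconv_ms = 5.862883790766858

# Latency reduction relative to the torch baseline, in percent.
reduction_pct = (torch_conv_ms - catconv_ms) / torch_conv_ms * 100

# Equivalent speedup factor for the HarDNet85 test.
speedup = torch_conv_ms / catconv_ms

# Throughput ratio for the FCHarDNet figures (70 fps -> 99 fps).
fps_ratio = 99 / 70

print(f"latency reduction: {reduction_pct:.2f} %")
print(f"speedup: {speedup:.3f}x")
print(f"FCHarDNet throughput ratio: {fps_ratio:.3f}x")
```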