In this project, I'm going to improve the runtime (latency) of batch inference and training of this particular ResNet (code copied from here) on Volta and Ampere architectures. I expect most of the speedup to come from effective use of Tensor Cores as well as operator fusion. My plan is to implement each ResNet block (including FixUp and Squeeze-and-Excitation, for both forward and backward) using CUTLASS. I believe its support for Tensor Cores, GEMM epilogues, and fused convolution+convolution will be helpful. The motivation comes from this paper.
-
Inputs:
- For inference: an (N, 3, 64, 64) tensor (corresponding to N RGB images) or an (N, 1, 64, 64) tensor (corresponding to N depth maps). Here N is the batch size, which is fixed for each GPU type; in the paper above, N = 128 on a Tesla V100. In particular, I will not change N to improve the throughput of the network.
- For training: a (16 * N, 3, 64, 64) tensor (16 * N samples per mini-batch: N agents × 32 rollout steps / 2 mini-batches), or a (16 * N, 1, 64, 64) tensor. (I didn't have enough time to get to this.)
-
Outputs:
- For inference: the output of the network, a tensor of shape (N, 512, 2, 2).
- For training: the gradients (or a gradient update?) of all parameters of the network, evaluated on the mini-batch. I don't think the intermediate activations of the network are cached at inference time, since the GPU resources need to be used for rendering; therefore the forward values need to be recomputed here, too (see the sketch below).
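A minimal sketch of that recomputation, assuming the network is available as `net`, the batches as `obs_batch` / `minibatch`, and `compute_loss` as a stand-in for the actual RL loss (none of these names come from the starter code):

```python
import torch

# Rollout / inference: run under no_grad so no intermediate activations are
# kept and GPU memory stays free for rendering.
with torch.no_grad():
    features = net(obs_batch)      # (N, 512, 2, 2), no autograd graph

# Training: the forward pass is recomputed with autograd enabled so that
# backward() can produce gradients for every parameter of the network.
features = net(minibatch)          # (16 * N, 512, 2, 2), graph is recorded
loss = compute_loss(features)      # hypothetical loss on the features
loss.backward()                    # gradients now live in p.grad
grads = [p.grad for p in net.parameters()]
```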
-
Constraint: it is possible that PyTorch/cuDNN already selects near-optimal kernels, in which case there is very limited room for improvement.
-
Task List:
- Profile the kernels used by PyTorch to identify the hotspots (pretty sure they are the convolutions, based on the output of this script, run on my RTX 3070). The PyTorch implementation will be the baseline. (A profiling sketch follows this list.)
- Get familiar with CUTLASS. Implement an unoptimized ResNet block using CUTLASS.
- Make the implementation from the last step a PyTorch operator. I will follow this tutorial. Confirm that the results are correct. (A correctness-check sketch follows this list.)
- Optimize. The CUTLASS profiler might be handy for this step.
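A rough way to do the profiling step from Python, using torch.profiler; `net` stands for the ResNet from the starter code and is an assumed handle:

```python
import torch
from torch.profiler import profile, ProfilerActivity

net = net.cuda().eval()                    # assumed: ResNet from the starter code
x = torch.randn(128, 3, 64, 64, device="cuda")

# Warm up so cuDNN autotuning / lazy initialization does not pollute the trace.
for _ in range(10):
    with torch.no_grad():
        net(x)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        net(x)
torch.cuda.synchronize()

# Sort by CUDA time to see which kernels dominate (expected: the convolutions).
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```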
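For the operator step, a hedged sketch of how I plan to load the extension and check correctness. The file names, the exported function `resnet_block_forward`, and the handle `reference_block` are assumptions about my own layout, not existing APIs:

```python
import torch
from torch.utils.cpp_extension import load

# Build and load the CUTLASS-backed extension (file names are placeholders).
ext = load(
    name="cutlass_resnet_block",
    sources=["resnet_block.cpp", "resnet_block_kernels.cu"],
    extra_cuda_cflags=["-O3"],
)

x = torch.randn(128, 64, 16, 16, device="cuda")

# `reference_block` is the block from the starter code; the extension call is
# the CUTLASS implementation wrapped as a PyTorch operator.
with torch.no_grad():
    expected = reference_block(x)
    actual = ext.resnet_block_forward(x, *list(reference_block.parameters()))

# fp16 / Tensor Core math will not match fp32 bit-for-bit, hence loose tolerances.
assert torch.allclose(actual, expected, rtol=2e-2, atol=2e-2)
```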
-
Expected deliverables: an optimized version of the ResNet, available as a child class of nn.Module (sketched below), and a bar plot that shows the runtime reduction of the optimized implementation relative to the baseline.
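A minimal sketch of what that child class could look like; the operator name `torch.ops.cutlass_resnet.fused_forward` is an assumption about how I will register the extension, not an existing API:

```python
import torch
import torch.nn as nn

class OptimizedResNet(nn.Module):
    """Hypothetical sketch of the deliverable: same interface and parameters as
    the baseline ResNet, but forward() dispatches to the CUTLASS-backed ops."""

    def __init__(self, baseline: nn.Module):
        super().__init__()
        # Reuse the baseline's parameters so checkpoints stay compatible and
        # both models can be benchmarked on identical weights.
        self.baseline = baseline

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed operator registered by the custom extension, not an existing API.
        return torch.ops.cutlass_resnet.fused_forward(
            x, list(self.baseline.parameters())
        )
```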
-
Dependencies: most of them are mentioned above. To summarize:
- The starter code that defines the network to be optimized.
- CUTLASS and tutorials on how to use it.
- A tutorial on how to write custom PyTorch operators.
Performance is defined as the average latency of batch inference, averaged over 10000 batch-inference calls with batch size N = 128 each, run on an RTX 3070. The baseline is a JIT-compiled PyTorch implementation (PyTorch built for CUDA 11.1); the kernels in my implementation are all built with CUDA 11.3. See here for more details. (A timing sketch appears below, after the results.)
- Baseline: 2.77501 milliseconds
- 2efee8a: 0.954 milliseconds. This version fuses the memory-bandwidth-bound Squeeze-and-Excitation block (an unfused reference is sketched below) and uses CUTLASS implicit GEMM to compute the convolutions.
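For context on what gets fused: a standard (unfused) Squeeze-and-Excitation block looks roughly like the sketch below. The starter-code variant may differ in details (e.g. the reduction ratio), but each step is a separate small, memory-bound kernel in eager PyTorch, which is why fusing the block pays off.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SqueezeExcite(nn.Module):
    """Reference (unfused) Squeeze-and-Excitation block in its usual form;
    details of the starter-code version may differ."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pooling, two small GEMMs, activations, and a broadcasted multiply:
        # each is a tiny, bandwidth-bound kernel launch in eager mode.
        s = x.mean(dim=(2, 3))            # squeeze: global average pool
        s = F.relu(self.fc1(s))           # excite: bottleneck FC + ReLU
        s = torch.sigmoid(self.fc2(s))    # per-channel gate in [0, 1]
        return x * s[:, :, None, None]    # channel-wise rescale
```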
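And a sketch of a timing harness matching the definition above; the function name and warm-up count are my own choices, not taken from the linked script.

```python
import torch

def mean_latency_ms(model, iters=10_000, batch=128):
    """Average batch-inference latency in milliseconds over `iters` calls."""
    x = torch.randn(batch, 3, 64, 64, device="cuda")

    # Warm-up so JIT compilation / autotuning is excluded from the average.
    for _ in range(100):
        with torch.no_grad():
            model(x)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        with torch.no_grad():
            model(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # elapsed_time() returns ms
```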