Profiling GPT2 Inference Latency (FP32)

December 02, 2022

The question I ask here: does inference with a large GPT-2 model fully utilize the hardware capabilities of an NVIDIA A100? Does this expensive hardware pay off for Transformer-based inference? Large Language Models (LLMs) like GPT-2 have generated significant activity in optimizing machine learning inference along two major axes: cost of inference (largely driven by hardware or cloud cost) and latency of inference. The default choice for running inference has mostly been NVIDIA GPUs, which are among the most expensive hardware options available for the task. Since researchers are now investigating transformer architectures to complement or even replace traditional architectures like convolutional networks, I think it's worthwhile to briefly consider the performance characteristics of inference on convolutional networks, to serve as a point of comparison for the LLM.

At small batch sizes, typical convolutional neural networks tend to underutilize the compute capabilities of high-end GPUs but still benefit in latency from high memory bandwidth. The reason is that small kernels with low occupancy can bottleneck the critical path, resulting in bursts of high compute intensity followed by periods of low intensity (i.e. underutilization). For example, the trace below shows the stream of kernels (zoomed out too far to read the individual kernel names) along with some high-level performance metrics such as occupancy, compute throughput, and memory throughput.

image
Inference of efficientnet_v2_l model with batch=1 on NVIDIA A100 (TRT engine)

This raises a fair question about the tradeoff between cost and latency. If one expects to mostly do inference at small batch sizes (a common use case), could a less capable but cheaper machine deliver the same results? We can't answer this in a vacuum without knowing the application, but for small-batch inference of convolutional neural networks, very high-end GPUs do not deliver much extra value relative to their substantially higher cost. But what about LLMs?

LLMs change the game in many respects compared to their predecessors. Their models tend to have many more parameters than convolutional neural networks, so many that they often will not fit within the memory of a single GPU. Since GPU memory capacity tends to increase significantly only on the much higher-end GPUs, this has pushed LLM inference onto those GPUs, driving up costs. Transformer architectures mark another departure and an important factor for evaluating performance. Convolutional layers themselves need to process relatively little data (the convolution kernel, input, and output), and any growth usually comes from increasing the number of channels in the layer (a linear dependence). Multi-headed attention layers, on the other hand, must compute (comparatively) large matrix-matrix multiplies, processing query, key, and value matrices, inputs, and outputs. The amount of data processed grows like the square of the input sequence length. Is it possible, then, that LLMs see kernels with greater occupancy or greater compute intensity and fewer low-intensity critical-path bottlenecks than convolutional networks, leading to better use of high-end GPUs? I investigate below.
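As a rough illustration of that quadratic scaling (illustrative numbers only; the head count below is a GPT2-large-like assumption, not a measurement of either model):

def attention_score_bytes(seq_len, n_heads=20, bytes_per_elem=4):
    # The attention score matrix is (seq_len x seq_len) per head: quadratic in sequence length.
    return n_heads * seq_len * seq_len * bytes_per_elem

for s in (256, 512, 1024, 2048):
    print(f"seq_len={s:5d}  attention scores ~{attention_score_bytes(s) / 2**20:6.0f} MiB")
# Doubling the sequence length quadruples the score matrices; a convolutional layer's
# activations, by contrast, grow only linearly with channel count for a fixed input size.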

GPT-2 Small Batch inference on NVIDIA A100

Basic Timing

I will benchmark GPT-2 inference on a single batch (batch size 1) with a sequence length of 1024 tokens.

I used a prebaked NVIDIA image on GCP with an NVIDIA A100; a screengrab of the details is below.

image

First I generated a TensorRT engine from this model and did the high-level benchmark (no hardware measurements, just times):

[12/04/2022-02:01:36] [I] === Trace details ===
[12/04/2022-02:01:36] [I] Trace averages of 10 runs:
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0561 ms - Host latency: 41.2429 ms (enqueue 0.151791 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0865 ms - Host latency: 41.274 ms (enqueue 0.148151 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0933 ms - Host latency: 41.2813 ms (enqueue 0.149274 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0886 ms - Host latency: 41.2748 ms (enqueue 0.137341 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0964 ms - Host latency: 41.2835 ms (enqueue 0.13833 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0831 ms - Host latency: 41.2698 ms (enqueue 0.155908 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.086 ms - Host latency: 41.2733 ms (enqueue 0.140601 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0899 ms - Host latency: 41.2763 ms (enqueue 0.148071 ms)
[12/04/2022-02:01:36] [I] Average on 10 runs - GPU latency: 31.0868 ms - Host latency: 41.2747 ms (enqueue 0.151758 ms)
[12/04/2022-02:01:36] [I] 
[12/04/2022-02:01:36] [I] === Performance summary ===
[12/04/2022-02:01:36] [I] Throughput: 31.8446 qps
[12/04/2022-02:01:36] [I] Latency: min = 41.1393 ms, max = 41.3042 ms, mean = 41.273 ms, median = 41.2759 ms, percentile(90%) = 41.2895 ms, percentile(95%) = 41.2952 ms, percentile(99%) = 41.3042 ms
[12/04/2022-02:01:36] [I] Enqueue Time: min = 0.120361 ms, max = 0.22583 ms, mean = 0.148609 ms, median = 0.146973 ms, percentile(90%) = 0.173096 ms, percentile(95%) = 0.180908 ms, percentile(99%) = 0.22583 ms
[12/04/2022-02:01:36] [I] H2D Latency: min = 0.00842285 ms, max = 0.0283203 ms, mean = 0.0102436 ms, median = 0.00927734 ms, percentile(90%) = 0.012207 ms, percentile(95%) = 0.0171509 ms, percentile(99%) = 0.0283203 ms
[12/04/2022-02:01:36] [I] GPU Compute Time: min = 30.9525 ms, max = 31.1133 ms, mean = 31.0858 ms, median = 31.0886 ms, percentile(90%) = 31.104 ms, percentile(95%) = 31.1082 ms, percentile(99%) = 31.1133 ms
[12/04/2022-02:01:36] [I] D2H Latency: min = 10.1726 ms, max = 10.1826 ms, mean = 10.1769 ms, median = 10.1768 ms, percentile(90%) = 10.1783 ms, percentile(95%) = 10.1787 ms, percentile(99%) = 10.1826 ms
[12/04/2022-02:01:36] [I] Total Host Walltime: 3.10885 s
[12/04/2022-02:01:36] [I] Total GPU Compute Time: 3.0775 s
[12/04/2022-02:01:36] [I] Explanations of the performance metrics are printed in the verbose logs.

So GPU time is about 31 ms for a single inference. I disregard PCI-E transfers here because in practice one can usually overlap them with application logic via the GPU DMA engines, so they contribute little or no end-to-end latency to the application.
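As an aside, a minimal sketch of that overlap pattern (plain PyTorch here rather than the TensorRT pipeline used above; tensor shapes and names are illustrative):

import torch

copy_stream = torch.cuda.Stream()
# Pinned host memory allows the DMA engine to perform the copy asynchronously.
host_batch = torch.randint(0, 50257, (1, 1024), dtype=torch.long).pin_memory()

with torch.cuda.stream(copy_stream):
    # Asynchronous H2D copy; it can run alongside compute already queued on the default stream.
    device_batch = host_batch.to("cuda", non_blocking=True)

# ... compute for the previous request would be running on the default stream here ...

# Make the default stream wait for the copy before kernels consume device_batch.
torch.cuda.current_stream().wait_stream(copy_stream)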

Deeper Perf Analysis

Next, using Nsight Compute and Nsight Systems, I traced and profiled the TensorRT engine to see which kernels were taking the most time and, overall, how well the engine used the underlying hardware. Combining information from these two sources, I looked at compute and memory throughput as well as theoretical and achieved kernel occupancy on top of the kernel trace, producing a trace like the one below:

image
Inference of GPT2-Large sequence length=1024, batch=1, on NVIDIA A100

The full trace is here; you can open it with chrome://tracing in Google Chrome.

Three major observations:

  1. Higher sustained compute throughput compared to efficientnet, with fewer bottlenecks.
  2. Achieved occupancy matches theoretical occupancy much more often than with efficientnet.
  3. Very high sustained memory throughput.

I conclude from this that while the compute throughput is indeed much higher than what a convolutional neural network would deliver, the overriding source of the A100's latency benefit for this low-batch inference is its high memory bandwidth (approaching 2 TB/s), which is almost fully utilized throughout the entire inference.
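A quick back-of-envelope check of that claim, using the numbers above (the ~2 TB/s figure is the A100's peak bandwidth, assumed to be nearly sustained as suggested by the trace):

latency_s = 31e-3                      # measured GPU compute time from trtexec above
bandwidth_bytes_per_s = 2e12           # ~2 TB/s A100 peak, assumed nearly sustained
implied_traffic_gb = latency_s * bandwidth_bytes_per_s / 1e9
print(f"implied DRAM traffic per inference: ~{implied_traffic_gb:.0f} GB")

weights_gb = 774e6 * 4 / 1e9           # GPT2-large: ~774M parameters in FP32
print(f"FP32 weights alone: ~{weights_gb:.1f} GB")

The implied traffic is roughly 20x the size of the FP32 weights, which would be consistent with activations and attention intermediates at sequence length 1024, rather than the weights themselves, dominating memory traffic; I have not broken the traffic down per kernel to confirm this.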

Since GPUs have a considerable advantage in memory bandwidth over CPUs, I decided to test this theory by running the same inference on an Intel machine.

GPT-2 Small Batch inference on Intel Cascade-Lake

For an Intel machine I used the following:

image

Basic Timing

To get an inference engine optimized for the Intel architecture, I used OpenVINO with the following commands:

mo --input_model simple_model.onnx 
benchmark_app -m simple_model.xml -hint latency -ip f32 -op f32 -report_type detailed_counters

which resulted in the following (considerably slower) measurements:

Latency:
    Median:     2101.50 ms
    AVG:        2100.02 ms
    MIN:        2085.78 ms
    MAX:        2126.31 ms

Even when we account for the fact that this is an underpowered (and cheaper) system compared to the NVIDIA one, this is wildly out of proportion with the excellent latency on the A100.

Machine type           GPT2 Inference Latency   Cost ($/month)   Latency * Cost
A100 (NVIDIA)          40 ms                    2000             80000
Cascade Lake (Intel)   2000 ms                  500              1000000

When we control for memory bandwidth, however, the results make more sense:

Machine type           Memory Bandwidth (peak)   Bandwidth * Inference Latency (GB)
A100 (NVIDIA)          2 TB/s                    80
Cascade Lake (Intel)   0.1 TB/s                  200
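For completeness, the arithmetic behind the two tables; the bandwidth * latency product comes out in GB and is a rough proxy for the data that could be moved during one inference:

machines = {
    "A100 (NVIDIA)":        {"latency_ms": 40,   "cost_per_month": 2000, "bandwidth_tb_s": 2.0},
    "Cascade Lake (Intel)": {"latency_ms": 2000, "cost_per_month": 500,  "bandwidth_tb_s": 0.1},
}
for name, m in machines.items():
    latency_cost = m["latency_ms"] * m["cost_per_month"]   # ms * $/month
    bw_latency = m["bandwidth_tb_s"] * m["latency_ms"]     # TB/s * ms = GB
    print(f"{name:22s} latency*cost={latency_cost:8.0f}  bandwidth*latency={bw_latency:5.0f} GB")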

This memory bandwidth advantage may disappear in coming months as CPU vendors roll out their on-package High-Bandwidth-Memory (HBM).

Summary and Conclusions

The results above do suggest that GPT-2 inference at small batch sizes makes good use of the A100, with some room to improve compute throughput (perhaps by better tuning of shared memory and/or caching in the attention layers). Because the A100 has excellent memory bandwidth and GPT-2 achieves very high sustained memory throughput, we see very good performance on the A100, possibly justifying its cost and use (beyond the simple fact that these instances have the most GPU memory and thus fit bigger LLMs).

Intel Cascade Lake performance was very poor, but since the gap appears to be due to the memory bandwidth difference between the two platforms, my findings suggest it may close in coming months as Intel and AMD push to include on-package High-Bandwidth Memory (HBM).

Methods and References

Model & checkpoint used and conversion to onnx

I used the Hugging Face GPT-2 checkpoint for gpt2-large and converted it to ONNX with the torch.onnx module.

The rough script for this was

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

# Load the codeparrot tokenizer (trained for Python code tokenization)
tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot")

# Config: "scale_attn_by_inverse_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
config_kwargs = {
    "vocab_size": len(tokenizer),
    "scale_attn_by_inverse_layer_idx": True,
    "reorder_and_upcast_attn": True,
}

# Load the model config (GPT-2 large in this case)
config = AutoConfig.from_pretrained("gpt2-large", **config_kwargs)

# Initialize a new model from the config
model = AutoModelForCausalLM.from_config(config)

# Wrap the model so the key/value cache is disabled and does not
# get traced during ONNX export
class TmpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tmp = model

    def forward(self, x):
        return self.tmp(x, use_cache=False)

tmpmodel = TmpModel()

torch.onnx.export(tmpmodel,                                    # model being run
                  torch.randint(1, len(tokenizer), (1, 1024)), # example input for the model
                  "simple_model.onnx",                         # where to save the model (file or file-like object)
                  export_params=True,                          # store the trained parameter weights inside the model file
                  opset_version=15,                            # the ONNX opset version to export to
                  do_constant_folding=True,                    # whether to execute constant folding for optimization
                  input_names=['input'],                       # the model's input names
                  output_names=['output'],                     # the model's output names
                  training=torch.onnx.TrainingMode.EVAL)

Converting ONNX to TRT Engine

To get fp32 and fp16 engines with extra profiling information like NVTX markers (for use in nsys) I did the following:

trtexec --onnx=simple_model.onnx --saveEngine=gpt2_fp32.engine --profilingVerbosity=detailed --workspace=30000
trtexec --onnx=simple_model.onnx --fp16 --saveEngine=gpt2_fp16.engine --profilingVerbosity=detailed --workspace=30000

Basic timings with trtexec

trtexec --loadEngine=gpt2_fp32.engine --useCudaGraph

Getting nsight systems trace

For the trace I did the following

nsys profile --trace=cuda,cudnn,cublas,osrt,nvtx trtexec --iterations=1 --loadEngine=gpt2_fp32.engine

Getting performance counters with nsight compute

To get performance counters I did

sudo env "PATH=$PATH" ncu --section ComputeWorkloadAnalysis --csv trtexec --loadEngine=gpt2_fp32.engine > cwa.csv
sudo env "PATH=$PATH" ncu --section SpeedOfLight --csv trtexec --loadEngine=gpt2_fp32.engine > sol.csv
sudo env "PATH=$PATH" ncu --section Occupancy --csv trtexec --loadENgine=gpt2_fp32.engine > occupancy.csv

Extracting kernel names and start/stop times from the nsys trace sqlite file

sqlite3 -csv report2.sqlite 'SELECT names.value AS name, start, end FROM CUPTI_ACTIVITY_KIND_KERNEL AS k JOIN StringIds AS names ON k.demangledName = names.id;' > kernels.csv

Combining perf counters and kernels into a single trace file

I combined the perf counter data from Nsight Compute with the kernel trace from Nsight Systems to get a visual representation of when the GPU was fully utilized. I did this with a python script you can find here.
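The linked script is what I actually used; a minimal sketch of the idea (file names follow the commands above, and the counter-matching step is only indicated) would be:

# Join kernel start/end times from nsys (kernels.csv, produced by the sqlite query above)
# with per-kernel counters from the ncu CSVs, and emit Chrome trace events for chrome://tracing.
import csv, json

with open("kernels.csv") as f:
    kernels = list(csv.DictReader(f, fieldnames=["name", "start", "end"]))

events = []
for k in kernels:
    start_ns, end_ns = int(k["start"]), int(k["end"])
    events.append({
        "name": k["name"],
        "ph": "X",                   # "complete" event: has a start timestamp and a duration
        "ts": start_ns / 1000.0,     # chrome://tracing expects microseconds
        "dur": (end_ns - start_ns) / 1000.0,
        "pid": 0, "tid": 0,
        # counters from sol.csv / occupancy.csv would be attached here,
        # matched to kernels by name and launch order
        "args": {},
    })

with open("trace.json", "w") as f:
    json.dump({"traceEvents": events}, f)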

The resulting file can be opened with chrome://tracing and looks something like below:

image
