
Train SOTA? Not so fast.


A lot of people want to train their own multimodal large language models (MLLMs). Assuming you got your data from Sweatshop AI and your GPUs from Jensen, you still can't really get started, for two main reasons:

  1. It is no longer possible to fit the parameters of these models in the memory of even the largest single GPU (80GB A100 cards still aren't cutting it).
  2. Even if you could fit the model in a single GPU by swapping parameters between host and device memory, the high number of compute operations required results in extremely long training times (e.g. training GPT-3 with 175 billion params would require 288 years with a single V100).
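
To put rough numbers on both points, here's a back-of-the-envelope sketch. Every constant in it is an assumption on my part, not a measurement: 16 bytes of training state per parameter (a common figure for mixed-precision Adam), 300 billion training tokens, roughly 6 FLOPs per parameter per token, and a hypothetical sustained 35 TFLOP/s on a single V100. Activations are ignored entirely.

```python
import math

# All constants are illustrative assumptions, not measurements.
PARAMS = 175 * 10**9          # GPT-3 parameter count (from the text)
BYTES_PER_PARAM = 16          # assumed: fp16 weights + grads, fp32 master weights + 2 Adam moments
GPU_MEMORY_BYTES = 80 * 10**9 # one 80GB A100
TOKENS = 300 * 10**9          # assumed number of training tokens
FLOPS_PER_PARAM_TOKEN = 6     # assumed rule of thumb for forward + backward
SUSTAINED_FLOPS = 35e12       # hypothetical sustained throughput of one V100

def training_state_bytes(params, bytes_per_param=BYTES_PER_PARAM):
    """Memory for weights, gradients and optimizer state (no activations)."""
    return params * bytes_per_param

def min_gpus_for_state(params, gpu_bytes=GPU_MEMORY_BYTES):
    """Lower bound on GPUs needed just to hold the training state."""
    return math.ceil(training_state_bytes(params) / gpu_bytes)

def training_years(params, tokens, flops_per_sec):
    """Single-device training time under the 6 * params * tokens approximation."""
    seconds = FLOPS_PER_PARAM_TOKEN * params * tokens / flops_per_sec
    return seconds / (365 * 24 * 3600)

if __name__ == "__main__":
    print(training_state_bytes(PARAMS) / 1e12, "TB of training state")
    print(min_gpus_for_state(PARAMS), "x 80GB GPUs just to hold it")
    print(round(training_years(PARAMS, TOKENS, SUSTAINED_FLOPS)), "years on one GPU")
```

With these inputs the training state alone comes to 2.8 TB, so at least 35 of the 80GB cards before a single activation is stored, and the single-GPU estimate lands in the same ballpark as the 288-year figure. Change the assumed throughput and the answer moves proportionally.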

Naturally we'll try to throw parallelism and multi-node clusters at this problem. For the most part, data-parallel scale-out works well for models that are particularly parameter efficient (i.e. they have a high ratio of FLOPs per forward pass to number of parameters). Unfortunately, this technique has its limits. As the number of GPUs increases, communication overhead kills performance, because more machines sit idle waiting for gradient updates from Ring All-Reduce or parameter servers.
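
To make the communication pattern concrete, here's a single-process simulation of a ring all-reduce over plain Python lists. It is a sketch of the textbook algorithm (reduce-scatter followed by all-gather), not the implementation used by NCCL or any other library; the function name and the list-of-lists representation are mine.

```python
def ring_all_reduce(worker_grads):
    """Sum equal-length gradient lists across workers arranged in a ring.

    Returns (per-worker results, number of communication steps).
    """
    n = len(worker_grads)
    size = len(worker_grads[0])
    # Chunk c covers indices bounds[c][0] .. bounds[c][1] - 1.
    bounds = [(c * size // n, (c + 1) * size // n) for c in range(n)]
    data = [list(g) for g in worker_grads]
    steps = 0

    # Phase 1: reduce-scatter. Each step, rank r sends one chunk to rank r + 1,
    # which adds it to its own copy of that chunk.
    for step in range(n - 1):
        sends = []
        for rank in range(n):
            chunk = (rank - step) % n
            lo, hi = bounds[chunk]
            sends.append((chunk, data[rank][lo:hi]))  # snapshot: sends are simultaneous
        for rank, (chunk, payload) in enumerate(sends):
            dst = (rank + 1) % n
            lo, _ = bounds[chunk]
            for k, value in enumerate(payload):
                data[dst][lo + k] += value
        steps += 1

    # Now rank r holds the fully reduced chunk (r + 1) % n.
    # Phase 2: all-gather. Pass the finished chunks around the ring.
    for step in range(n - 1):
        sends = []
        for rank in range(n):
            chunk = (rank + 1 - step) % n
            lo, hi = bounds[chunk]
            sends.append((chunk, data[rank][lo:hi]))
        for rank, (chunk, payload) in enumerate(sends):
            dst = (rank + 1) % n
            lo, hi = bounds[chunk]
            data[dst][lo:hi] = payload
        steps += 1

    return data, steps
```

Each worker sends about 2(n - 1)/n times the gradient size in total, which barely changes with n, but the number of sequential steps is 2(n - 1). So per-step latency and the slowest link in the ring matter more and more as the cluster grows.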

In recent years, methods have been proposed to address this shortcoming, such as tensor (intra-layer) model parallelism, where the matrix multiplications within each transformer layer are split over multiple GPUs (see Megatron). Although this works well for models of up to 20 billion parameters on NVIDIA DGX A100 servers (with 8 x 80GB A100 GPUs), it still breaks down for larger models. So let's do the logical thing and split the model across multiple multi-GPU servers. This also has drawbacks:

  1. The All-Reduce communication required for tensor parallelism has to go through inter-server links, which are slower than the high-bandwidth NVLink available within a multi-GPU server. Some techniques have been tried to reduce this communication cost (see DiLoCo).

  2. A high degree of model parallelism can create small matrix multiplications, decreasing GPU utilization (see tile efficiency). If a large model fully utilizes a single GPU during a GEMM, splitting it across 4 GPUs shrinks the GEMM on each individual GPU, leaving every device underutilized. Furthermore, if the sharding (partitioning) breaks operations into smaller matrices or sub-operations, each device ends up running more, smaller GEMMs, which are less efficient for the GPU to compute.
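
Here's a toy sketch of both drawbacks at once, using plain Python lists as matrices. It splits a two-layer MLP the way Megatron-style tensor parallelism is usually described: the first weight matrix by columns, the second by rows, with a final sum standing in for the All-Reduce. The helper names are hypothetical and ReLU is a stand-in for whatever nonlinearity the real model uses.

```python
def matmul(a, b):
    """Plain-Python matrix multiply: (m x k) @ (k x n)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def relu(m):
    return [[max(v, 0) for v in row] for row in m]

def mlp_reference(x, w1, w2):
    """Unsharded MLP: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)

def mlp_tensor_parallel(x, w1, w2, num_shards):
    """Same MLP with w1 split by columns and w2 split by rows.

    Returns (output, per-shard shape (m, k, n) of the first GEMM).
    """
    hidden = len(w2)
    assert hidden % num_shards == 0, "hidden size must divide evenly"
    width = hidden // num_shards
    partials, gemm_shapes = [], []
    for shard in range(num_shards):
        lo, hi = shard * width, (shard + 1) * width
        w1_cols = [row[lo:hi] for row in w1]  # this shard's columns of w1
        w2_rows = w2[lo:hi]                   # matching rows of w2
        h = relu(matmul(x, w1_cols))          # no communication needed here
        partials.append(matmul(h, w2_rows))   # partial sum of the output
        gemm_shapes.append((len(x), len(w1), width))
    # Stand-in for the All-Reduce: sum the partial outputs across shards.
    y = [[sum(p[i][j] for p in partials) for j in range(len(w2[0]))]
         for i in range(len(x))]
    return y, gemm_shapes
```

The output matches the unsharded MLP, but every extra shard adds a participant to the All-Reduce and narrows each device's first GEMM from `hidden` columns to `hidden / num_shards`.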

Pipeline parallelism is another technique used to support the training of larger models. The model is divided into stages, each consisting of a shard of the model's layers, and the stages are distributed across multiple GPUs (e.g. if you have a model with 12 layers and 3 GPUs, you might assign 4 layers to each GPU, or pick some different arrangement). Layers can be assigned to workers in various ways, and various schedules for the forward and backward passes can be used. The layer assignment and scheduling strategy result in different performance tradeoffs.

Regardless of the schedule chosen, the optimizer steps need to be synchronized across devices, which leads to a pipeline flush at the end of every batch, where:
  1. Ongoing microbatches are allowed to complete execution.
  2. No new microbatches are injected into the pipeline.
The pipeline flush can consume up to 50% of the total training time, depending on the number of microbatches injected into the pipeline. These strict optimizer semantics ensure training consistency at the cost of performance. Researchers have tried to relax them to improve throughput, but that seems to come at the cost of convergence / final model accuracy (see PipeDream-2BW).
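
To see where that idle time comes from, here's a small simulation of a flush-based schedule: every stage runs all of its forward passes, then all of its backward passes, then the batch ends. It assumes uniform stage times (forward = 1 unit, backward = 2 units, both made up), free communication, and a contiguous layer split; real schedules and real models are messier.

```python
def assign_layers(num_layers, num_stages):
    """Contiguous, near-even split of layer indices across pipeline stages."""
    base, extra = divmod(num_layers, num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        count = base + (1 if s < extra else 0)
        stages.append(list(range(start, start + count)))
        start += count
    return stages

def simulate_flush_schedule(num_stages, num_microbatches, t_fwd=1, t_bwd=2):
    """Return (total time per batch, fraction of that time each stage is idle)."""
    p, m = num_stages, num_microbatches
    # Forward of microbatch j on stage s waits for stage s - 1 and for j - 1.
    fwd_done = [[0] * m for _ in range(p)]
    for s in range(p):
        for j in range(m):
            upstream = fwd_done[s - 1][j] if s > 0 else 0
            previous = fwd_done[s][j - 1] if j > 0 else 0
            fwd_done[s][j] = max(upstream, previous) + t_fwd
    # Backward runs from the last stage to the first, after all forwards on that stage.
    bwd_done = [[0] * m for _ in range(p)]
    for s in reversed(range(p)):
        for j in range(m):
            downstream = bwd_done[s + 1][j] if s < p - 1 else fwd_done[s][j]
            previous = bwd_done[s][j - 1] if j > 0 else fwd_done[s][m - 1]
            bwd_done[s][j] = max(downstream, previous) + t_bwd
    total = bwd_done[0][m - 1]        # the flush ends when stage 0 finishes
    busy = m * (t_fwd + t_bwd)        # useful work per stage
    return total, 1 - busy / total

def bubble_fraction(num_stages, num_microbatches):
    """Closed form for the same schedule: (p - 1) / (m + p - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)
```

Under these assumptions, 4 stages with 4 microbatches leave each GPU idle for 3/7 of the batch, and the fraction creeps toward one half as the stage count grows with the microbatch count held equal to it. Pushing many more microbatches through the pipeline per batch is what shrinks the bubble.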

I'd like to think that most people who have read this far have come to the realization that there are too many modes of parallelism to handle. Maybe it'd be better to combine all of them in a hybrid approach. Unfortunately, integrating pipeline, data, and tensor parallelism haphazardly causes non-trivial interactions: they directly affect the amount of communication, the compute efficiency with which kernels are executed, and the idle time workers spend waiting on pipeline flushes (referred to as pipeline bubbles). A poor combination of parallelism modes can lead to over 2x lower throughput, even with high-bandwidth links between servers.
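
As a tiny illustration of how these choices interact, here's a sketch that enumerates (tensor, pipeline, data) degrees for a fixed GPU count and reports the pipeline bubble for each. The constraints are my own simplifications: tensor parallelism stays inside one server, pipeline stages divide the layers evenly, and the bubble uses the flush-schedule formula (p - 1) / (m + p - 1) with uniform stage times. It says nothing about communication cost or kernel efficiency, which a real search would also have to model.

```python
def parallel_configs(num_gpus, gpus_per_node, num_layers):
    """All (tensor, pipeline, data) degrees with t * p * d == num_gpus."""
    configs = []
    for t in range(1, gpus_per_node + 1):
        # Assumption: tensor-parallel groups must fit evenly inside a server.
        if num_gpus % t or gpus_per_node % t:
            continue
        rest = num_gpus // t
        for p in range(1, rest + 1):
            # Assumption: every pipeline stage gets the same number of layers.
            if rest % p or num_layers % p:
                continue
            configs.append((t, p, rest // p))
    return configs

def bubble_for(config, microbatches_per_batch):
    """Pipeline bubble fraction when the batch is shared by d data-parallel replicas."""
    _, p, d = config
    m = microbatches_per_batch // d   # microbatches each pipeline replica sees
    if m < 1:
        raise ValueError("not enough microbatches for this data-parallel degree")
    return (p - 1) / (m + p - 1)

if __name__ == "__main__":
    for cfg in parallel_configs(num_gpus=16, gpus_per_node=8, num_layers=12):
        print(cfg, round(bubble_for(cfg, microbatches_per_batch=64), 3))
```

Even in this toy version the modes pull against each other: for a fixed batch, raising the data-parallel degree leaves each pipeline with fewer microbatches, which makes the bubble bigger for the same number of stages.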

Everything I've shared so far barely scratches the surface of large-scale distributed training, and it doesn't even touch the practical considerations of taking this into production, like budget, network topology, and constantly changing model architectures. So how do we get high-throughput, low-communication, cost-efficient training? And can we mirror this paradigm for inference?

I got some ideas.

Always laugh at the competition.

- hyena

09/04/2024