In this blog post we explore Fully Sharded Data Parallelism (FSDP), a technique for efficiently training large neural network models across multiple GPUs. We’ll examine FSDP from a bird’s eye view and shed light on the underlying mechanisms.
When choosing a distributed learning approach, it’s important to be aware of the advantages and disadvantages of each strategy in order to implement one that matches the target use-case at hand.
For large language models (LLMs) specifically, the large number of parameters translates into significant GPU memory requirements, and FSDP is a high-performance way to address them. By sharding the model across multiple GPUs, FSDP offers a practical trade-off: it accepts additional GPU communication in exchange for lower memory usage on each GPU.
On the other hand, for Computer Vision models that typically can fit on a single GPU, Distributed Data Parallelism (DDP) often proves more efficient, allowing for the entire model to be run while avoiding the GPU communication overhead associated with FSDP.
Suppose we find ourselves wanting to run a large model that won't fit on a single GPU. The “naive” approach to running the model involves splitting the model in a so-called “vertical” way by allocating different model layers to different GPUs. Each GPU handles a specific set of layers, and the whole model can be run sequentially. Problem solved, right? Unfortunately, that’s not the case—this naive method has significant limitations, and we’ll show that FSDP is an answer to overcoming them.
Let's consider a scenario with n GPUs and a model size denoted as S, in which the model cannot fit on one GPU but can fit on all the GPUs combined. After splitting the model “vertically,” each GPU has a slice of the model of size S/n. The time it takes for a single GPU to run the forward pass on its model slice is represented as T_forward, and the backward run time is denoted as T_backward.
With this approach, a complete forward pass takes just T_forward*n, which matches the best case for FSDP. Moreover, only the activations are passed between GPUs, not the model weights, gradients, or optimizer state.
This strategy, at first glance, seems to be a valid way to train a model with multiple GPUs. Given how straightforward it is, one might question the need for anything more complicated: there seems to be no need for splitting the dataset between the GPUs or for performing complicated GPU communication operations like those used in FSDP.
The issue emerges when the first GPU completes its forward pass and is ready to run the forward pass on the next batch. We cannot begin the next forward pass before the first GPU’s weights have been updated, so the first GPU must wait for all downstream GPUs to complete their forward passes and propagate their gradients backwards. This results in significant idle time. Specifically, the first GPU waits for T_forward*(n-1) + T_backward*(n-1) before it can run backward propagation for the first batch.
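To put a number on that idle time, here is a minimal Python sketch. It assumes every GPU's slice takes the same T_forward and T_backward and ignores the cost of passing activations between GPUs; the function names are ours, purely for illustration.

```python
def naive_pipeline_idle_time(n_gpus: int, t_forward: float, t_backward: float) -> float:
    """Idle time of the first GPU between finishing its forward pass
    and starting its backward pass, under naive vertical splitting."""
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    # The other n-1 GPUs must each run forward, then each run backward.
    return t_forward * (n_gpus - 1) + t_backward * (n_gpus - 1)


def naive_pipeline_utilization(n_gpus: int, t_forward: float, t_backward: float) -> float:
    """Fraction of one training step during which a single GPU does useful work."""
    step_time = n_gpus * (t_forward + t_backward)  # whole model, run sequentially
    busy_time = t_forward + t_backward             # this GPU's own slice
    return busy_time / step_time


if __name__ == "__main__":
    print(naive_pipeline_idle_time(4, 1.0, 2.0))
    print(naive_pipeline_utilization(4, 1.0, 2.0))
```

Under these assumptions each GPU is busy for only 1/n of a training step, so adding GPUs makes the waste worse rather than better.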
We’ve split our model among our GPUs, but most of the time, each GPU is just sitting around doing nothing, waiting for the data to propagate elsewhere in the model!
FSDP offers a solution to this problem, providing the ability to fully make use of all the GPUs on large models without significant idle GPU time.
There are two separate actions that take place to set up the FSDP process: vertical splitting and horizontal splitting.
Just as in the naive solution above, vertical splitting organizes the model layers into “units.” For example, in a model with 9 convolution layers, each unit is responsible for a specific range of layers: the first unit might manage layers 1 to 3, the second unit layers 4 to 6, and the last unit layers 7 to 9.
“Horizontal splitting” refers to the splitting of the model parameters within each layer and storing them on individual GPUs; this process is also commonly called sharding. For example, for 3 fully connected layers, instead of storing each fully connected layer on one GPU, each GPU holds one third of the fully connected layer entities (that is, the parameters, gradients, and optimizer states).
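To make horizontal splitting concrete, here is a toy sketch in plain Python (no PyTorch involved). It assumes a layer's parameters have been flattened into a single list and are zero-padded so that they divide evenly between GPUs; shard_flat_params and all_gather are illustrative stand-ins, not library functions.

```python
def shard_flat_params(flat_params, n_gpus):
    """Split a flat parameter list into n_gpus equal-size shards.

    The tail is zero-padded so every shard has the same length
    (an assumption of this sketch).
    """
    shard_size = -(-len(flat_params) // n_gpus)  # ceiling division
    padding = shard_size * n_gpus - len(flat_params)
    padded = list(flat_params) + [0.0] * padding
    return [padded[i * shard_size:(i + 1) * shard_size] for i in range(n_gpus)]


def all_gather(shards, original_length):
    """Rebuild the full parameter list from every GPU's shard, dropping padding."""
    full = [value for shard in shards for value in shard]
    return full[:original_length]


if __name__ == "__main__":
    layer = [float(i) for i in range(10)]  # 10 parameters, 3 GPUs
    shards = shard_flat_params(layer, 3)
    print(shards)                          # each inner list lives on one GPU
    print(all_gather(shards, len(layer)) == layer)
```

Each inner list is what a single GPU stores permanently; the full layer only exists on a GPU for as long as it is needed for computation.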
Throughout the training process, the GPUs collaborate by sharing the shards each of them needs. This incurs communication overhead between the GPUs and requires temporarily holding copies of parameters owned by other GPUs, but in exchange we’re able to keep all the GPUs busy at all times.
During the forward and backward steps, all GPUs run through the units one by one, in parallel with each other, gathering the necessary shards of model parameters and other entities from the other GPUs as they go.
In PyTorch's FSDP, there are multiple configuration settings known as “sharding strategies” that govern the distribution and management of model shards. This blog post focuses on the FULL_SHARD sharding strategy, which is the most memory-efficient strategy but also a communication-intensive one.
Under the FULL_SHARD strategy, the following key entities are subjected to sharding: the model parameters, the gradients, and the optimizer states.
In our example above, we assumed we have three available GPUs, and we will once again assume that the model cannot fit on one GPU but can fit on all of them combined. The neural network that we will use for our toy example has 9 layers and each unit will be assigned 3 layers. The extension to the more common case where the model can’t fit on all GPUs at once is straightforward once the machinery of this toy example has been explained.
We will define the following terms:
MEM_total: the memory size of all the parameters to be stored, i.e., MP + GRD + OS
The following initial steps set the stage for FSDP:
MEM_total/3 in its memory.
💡 As mentioned, the sharding of the model entities is “horizontal,” meaning that each shard includes model parameters from every layer, as shown in the following diagram:
It's important to note that in PyTorch's FSDP with the 'FULL_SHARD' strategy, both the gradients and optimizer state are shown as they are before the first backward and optimizer steps. At this stage, these entities are not yet calculated; they are placeholders. This implies that they will only contain actual, calculated information after being updated during the backward pass and the optimizer step.
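As a rough worked example of the MEM_total/3 bookkeeping, the sketch below computes MEM_total = MP + GRD + OS and the share each GPU stores. The byte sizes are assumptions of ours, not something fixed by FSDP: 4 bytes per parameter and per gradient (fp32), and 8 bytes of optimizer state per parameter (two fp32 values, as an Adam-style optimizer would keep).

```python
def sharded_memory_bytes(n_params, n_gpus,
                         bytes_per_param=4,      # assumption: fp32 weights
                         bytes_per_grad=4,       # assumption: fp32 gradients
                         bytes_per_opt_state=8): # assumption: two fp32 values per parameter
    """Return (MEM_total, memory per GPU) for the persistently stored entities."""
    mp = n_params * bytes_per_param        # model parameters
    grd = n_params * bytes_per_grad        # gradients
    opt = n_params * bytes_per_opt_state   # optimizer state
    mem_total = mp + grd + opt
    return mem_total, mem_total / n_gpus


if __name__ == "__main__":
    total, per_gpu = sharded_memory_bytes(n_params=3_000_000, n_gpus=3)
    print(total, per_gpu)
```

This only counts the shards each GPU owns. Activations and the shards temporarily gathered from other GPUs during computation come on top of it.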
The forward pass consists of the following steps:
💡 In the diagrams, opaque colors indicate that the shard is “owned” by the GPU and will persist throughout the entire training process. Conversely, a low-opacity filling denotes sharded entities not attached to the GPU, and they will be discarded after usage during resharding.
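The distinction between owned shards and temporarily gathered ones can be sketched as a small simulation. This is a simplified model of ours, not PyTorch's implementation: it counts parameters only, assumes each unit is gathered right before its forward computation and the borrowed shards are freed right after, and uses a hypothetical list of unit sizes as input.

```python
def peak_params_during_forward(unit_sizes, n_gpus):
    """Peak number of parameters a single GPU holds during one forward pass.

    unit_sizes: number of parameters in each unit (hypothetical input).
    """
    owned = sum(unit_sizes) / n_gpus  # owned shards persist for the whole run
    peak = owned
    for size in unit_sizes:
        # Gather: fetch the parts of this unit that other GPUs own.
        borrowed = size - size / n_gpus
        # The forward computation for this unit happens while it is fully present.
        peak = max(peak, owned + borrowed)
        # Reshard: the borrowed shards are discarded before the next unit.
    return peak


if __name__ == "__main__":
    # Toy example from the text: 3 GPUs, 3 units of equal size.
    print(peak_params_during_forward([300, 300, 300], n_gpus=3))
```

In this toy setting a GPU never holds more than its own third of the model plus the missing two thirds of a single unit, rather than the full 900 parameters.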
The backward pass consists of the following steps:
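One collective operation worth understanding for the backward pass is reduce-scatter, which PyTorch's FSDP uses to combine the gradients computed by the different GPUs so that each GPU ends up with only its own gradient shard. The sketch below is a plain-Python stand-in rather than the library's code; it assumes gradients are averaged across GPUs and that the gradient length divides evenly by the number of GPUs.

```python
def reduce_scatter_mean(per_gpu_grads):
    """per_gpu_grads[g] is the full gradient list GPU g computed on its own batch.

    Returns one gradient shard per GPU: the element-wise mean across GPUs,
    split into equal contiguous pieces.
    """
    n_gpus = len(per_gpu_grads)
    length = len(per_gpu_grads[0])
    if length % n_gpus != 0:
        raise ValueError("this sketch assumes the length divides evenly")
    # Reduce: average each gradient element across all GPUs.
    reduced = [sum(grads[i] for grads in per_gpu_grads) / n_gpus
               for i in range(length)]
    # Scatter: GPU r keeps only the r-th contiguous piece.
    shard_size = length // n_gpus
    return [reduced[r * shard_size:(r + 1) * shard_size] for r in range(n_gpus)]


if __name__ == "__main__":
    grads = [[1.0, 2.0, 3.0],   # computed on GPU 0
             [4.0, 5.0, 6.0],   # computed on GPU 1
             [7.0, 8.0, 9.0]]   # computed on GPU 2
    print(reduce_scatter_mean(grads))
```

After this step, each GPU holds the gradient values that match the parameter shard it owns, which is all it needs for its part of the optimizer step.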
This concludes the description of the FSDP training process. Overall, we have seen that this process includes somewhat complicated interaction operations between multiple GPUs, but results in minimal GPU idle time. In this way, FSDP makes the most of available computing resources and allows the training of large models in an efficient manner in a server-side training environment.
In our next blog post, we will delve into the practical aspects of FSDP including code snippets from PyTorch FSDP, guiding you through the process of training a model using FSDP on GCP. Stay tuned for hands-on insights and implementation details!