Large language model distributed training

Large Language Models
Author

Alex Chen

Published

March 2, 2024

Tip

AWS SageMaker is a managed service that automates model training, and its price is roughly 1.5x that of running the same workload on plain elastic-container instances. Distributed training is therefore both important and expensive.

Introduction to distributed training in PyTorch

We should first know which kinds of distributed training are available. According to the PyTorch documentation, there are Distributed Data Parallel (DDP), RPC-based distributed training, and collective communication primitives (read the documentation for the details).
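
To give a feel for the lowest-level option, here is a minimal collective-communication sketch; the gloo backend, the port, and the two-process setup are just assumptions for illustration:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def all_reduce_demo(rank, world_size):
    # "gloo" runs on CPU; use "nccl" on GPUs
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    tensor = torch.tensor([float(rank)])
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # every rank now holds 0 + 1 + ... + (world_size - 1)
    print(f"rank {rank}: {tensor.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    mp.spawn(all_reduce_demo, args=(2,), nprocs=2, join=True)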

Data Parallel

DistributedDataParallel (DDP) is preferable to DataParallel (DP), since DP runs in a single process and is therefore limited by the GIL. DP splits each input batch across the available GPUs, computes on every GPU, and then reduces the results. Suppose the forward pass uses a batch size of 16 and there are 4 GPUs: each GPU then processes a sub-batch of 4. Applying it only takes a few lines of code:

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

No other changes are needed to make it run.
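
To make the batch splitting concrete, here is a minimal usage sketch; the toy model, device strings, and tensor shapes are only assumptions for illustration:

import torch
import torch.nn as nn

# A toy model standing in for the real network
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicate the model and split each batch across GPUs
model = model.to("cuda:0")          # parameters live on the primary device

inputs = torch.randn(16, 1, 28, 28).to("cuda:0")  # batch of 16 -> sub-batches of 4 on 4 GPUs
outputs = model(inputs)             # scatter, parallel forward, gather results on cuda:0
print(outputs.shape)                # torch.Size([16, 10])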

Distributed Data Parallel (DDP)

DDP requires its own module, torch.nn.parallel.DistributedDataParallel, and launches one process per device, which avoids the GIL bottleneck. A code example:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import os
from torch.nn.parallel import DistributedDataParallel as DDP

class SimpleCNN(nn.Module):
    # A minimal placeholder CNN for 28x28 single-channel inputs (e.g., MNIST) with 10 classes
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 28 * 28, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(x.flatten(1))

def example(rank, world_size):
    # create default process group, nccl means running on GPU
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # create local model and move it to the current device (GPU/CPU)
    model = SimpleCNN().to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(64, 1, 28, 28).to(rank))  # Example input size for MNIST
    labels = torch.randint(0, 10, (64,)).to(rank)  # Example labels for 64 samples
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__=="__main__":
    # Environment variables for distributed training
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()

Tip

rank and world_size are two key concepts in distributed training. When we launch multiple processes to train the model, the total number of processes is the world_size, and each process is identified by its rank. You can think of a rank as a device index: we move the model and data to a rank just as we move them to a CUDA device.
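
As a small illustration (assuming init_process_group has already been called in the process, and reusing SimpleCNN from the example above), the two values can be queried directly:

import torch
import torch.distributed as dist

rank = dist.get_rank()                  # index of this process: 0 .. world_size - 1
world_size = dist.get_world_size()      # total number of processes
device = torch.device(f"cuda:{rank}")   # common convention: one GPU per rank
model = SimpleCNN().to(device)          # move the model to "the rank", i.e. its GPU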

Tip

We still need to set the environment variables (MASTER_ADDR and MASTER_PORT) for the distributed training code, since the framework needs them to set up the communication network between processes.

Use the ZeroRedundancyOptimizer

Optimizers such as Adam keep extra per-parameter state, usually about twice the size of the model, so out-of-memory (OOM) errors can occur. One remedy is the ZeRO-style sharded optimizer popularized by DeepSpeed, and PyTorch already implements it as ZeroRedundancyOptimizer:

from torch.distributed.optim import ZeroRedundancyOptimizer

If we want to use it, just add a flag called use_zero:

if use_zero:
    optimizer = ZeroRedundancyOptimizer(
        ddp_model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=0.01
    )
else:
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

This technique shards the optimizer state across the processes (or machines) to avoid OOM errors. All other code is the same as in the DDP example.
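
For completeness, a per-rank training step with this optimizer looks the same as before; this sketch reuses ddp_model, loss_fn, rank, and the use_zero branch from above:

outputs = ddp_model(torch.randn(64, 1, 28, 28).to(rank))
labels = torch.randint(0, 10, (64,)).to(rank)

optimizer.zero_grad()
loss_fn(outputs, labels).backward()  # DDP all-reduces the gradients as usual
optimizer.step()                     # each rank only stores and updates its shard of the optimizer state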

Fully Sharded Data Parallel (FSDP)

FSDP shards the model (along with gradients and optimizer state) and the data across all processes, which is especially useful when the model cannot fit on a single GPU. For an example script, refer to this code example.
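
As a rough sketch of how the wrapper is applied (reusing SimpleCNN and the shapes from the DDP example, with default FSDP settings assumed):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def fsdp_example(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = SimpleCNN().to(rank)
    fsdp_model = FSDP(model)   # parameters, gradients, and optimizer state are sharded across ranks

    optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=0.001)  # built after wrapping
    loss_fn = nn.CrossEntropyLoss()

    outputs = fsdp_model(torch.randn(64, 1, 28, 28).to(rank))
    labels = torch.randint(0, 10, (64,)).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    dist.destroy_process_group()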

torchrun

torchrun launches distributed training jobs elastically: it can tolerate the failure of some nodes and handles restarts automatically.

We should checkpoint after every epoch so that at most one epoch of training is lost on a restart. The code looks like:

def main():
    args = parse_args(sys.argv[1:])
    state = load_checkpoint(args.checkpoint_path)
    initialize(state)

    # torchrun (torch.distributed.run) ensures that this will work
    # by exporting all the env vars needed to initialize the process group
    torch.distributed.init_process_group(backend=args.backend)

    for i in range(state.epoch, state.total_num_epochs):
        for batch in iter(state.dataset):
            train(batch, state.model)

        state.epoch += 1
        save_checkpoint(state)

For more usage details about torchrun, refer to this page. Here is another script that can be run with the torchrun command. Before launching with torchrun, we should first make sure the script is adapted to it (see the sketch after the command below). The command to run it is:

torchrun
   --nnodes=NUM_NODES
   --nproc-per-node=TRAINERS_PER_NODE
   --max-restarts=NUM_ALLOWED_FAILURES
   --rdzv-id=JOB_ID
   --rdzv-backend=c10d
   --rdzv-endpoint=HOST_NODE_ADDR
   YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
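
Concretely, adapting a script to torchrun mostly means reading the rank information that torchrun exports instead of spawning processes yourself; a minimal sketch of the relevant environment variables:

import os
import torch
import torch.distributed as dist

# torchrun sets these for every process it launches
local_rank = int(os.environ["LOCAL_RANK"])   # GPU index on this node
rank = int(os.environ["RANK"])               # global process index
world_size = int(os.environ["WORLD_SIZE"])   # total number of processes

dist.init_process_group(backend="nccl")      # MASTER_ADDR / MASTER_PORT come from the env
torch.cuda.set_device(local_rank)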

For more complicated cases, PyTorch also provides solutions for running multiple containers that communicate with each other; see the Docker example or the Kubernetes example.