Modify a PyTorch Training Script
In this section, you learn how to modify PyTorch training scripts to configure the SageMaker model parallelism library for auto-partitioning and manual partitioning.
Note
To find which PyTorch versions are supported by the library, see Supported Frameworks and AWS Regions.
Tip
For end-to-end notebook examples that demonstrate how to use a PyTorch training script with the SageMaker model parallelism library, see Amazon SageMaker AI model parallelism library v1 examples.
Note that auto-partitioning is enabled by default. Unless otherwise specified, the following scripts use auto-partitioning.
Topics
- Automated splitting with PyTorch
- Manual splitting with PyTorch
- Considerations
- Unsupported framework features
Automated splitting with PyTorch
The following changes are required to run a PyTorch training script with SageMaker's model parallelism library:
- Import and initialize the library with smdistributed.modelparallel.torch.init().
- Wrap the model with smdistributed.modelparallel.torch.DistributedModel. Be mindful that any tensors returned from the forward method of the underlying nn.Module object are broadcast across model-parallel devices, incurring communication overhead, so any tensors that are not needed outside the call method (such as intermediate activations) should not be returned.
Note
For FP16 training, you need to use the smdistributed.modelparallel.torch.model_creation() context manager to wrap the model. For more information, see FP16 Training with Model Parallelism.
- Wrap the optimizer with smdistributed.modelparallel.torch.DistributedOptimizer.
Note
For FP16 training, you need to set up static or dynamic loss scaling. For more information, see FP16 Training with Model Parallelism.
- Use the returned DistributedModel object instead of a user model.
- Put the forward and backward logic in a step function and decorate it with smdistributed.modelparallel.torch.step.
- Restrict each process to its own device through torch.cuda.set_device(smp.local_rank()).
- Move the input tensors to the GPU using the .to() API before the smp.step call (see the example below).
- Replace torch.Tensor.backward and torch.autograd.backward with DistributedModel.backward.
- Perform post-processing on the outputs across microbatches using StepOutput methods such as reduce_mean.
- If there is an evaluation step, similarly place the forward logic inside an smp.step-decorated function and post-process the outputs using the StepOutput API (a sketch follows the full example below).
- Set drop_last=True in DataLoader. Alternatively, manually skip a batch in the training loop if the batch size is not divisible by the number of microbatches.
To learn more about the SageMaker model parallelism library API, refer to the API documentation.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchnet.dataset import SplitDataset
from torchvision import datasets

import smdistributed.modelparallel.torch as smp


class GroupedNet(nn.Module):
    def __init__(self):
        super(GroupedNet, self).__init__()
        # define layers

    def forward(self, x):
        # define forward pass and return model outputs


# smdistributed: Define smp.step. Return any tensors needed outside.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)
        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()
        optimizer.step()


# smdistributed: initialize the backend
smp.init()

# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda")

# smdistributed: Download only on a single process per instance.
# When this is not present, the file is corrupted by multiple processes trying
# to download and extract at the same time
dataset = datasets.MNIST("../data", train=True, download=False)

# smdistributed: Shard the dataset based on data-parallel ranks
if smp.dp_size() > 1:
    partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
    dataset = SplitDataset(dataset, partitions=partitions_dict)
    dataset.select(f"{smp.dp_rank()}")

# smdistributed: Set drop_last=True to ensure that batch size is always divisible
# by the number of microbatches
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

model = GroupedNet()
optimizer = optim.Adadelta(model.parameters(), lr=4.0)

# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

train(model, device, train_loader, optimizer)
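The complete example above only shows the training loop. The following is a minimal sketch of the evaluation-step guidance from the list above, reusing the imports and model from the example; the test_step and test names are illustrative, not part of the library. The evaluation logic follows the same pattern: decorate the forward pass with smp.step and post-process the returned StepOutput.

@smp.step
def test_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    return loss

def test(model, device, test_loader):
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            # smdistributed: Move input tensors to the current device before the smp.step call.
            data, target = data.to(device), target.to(device)
            # loss_mb is a StepOutput object; average it across microbatches.
            loss_mb = test_step(model, data, target)
            loss = loss_mb.reduce_mean()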
Manual splitting with PyTorch
Use smp.partition context managers to place modules on specific devices. Any module that is not created within an smp.partition context is placed in the default_partition. The default_partition needs to be provided if auto_partition is set to False. The modules that are created within a specific smp.partition context are placed on the corresponding partition.
To learn more about the SageMaker model parallelism library API, refer to the API documentation.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchnet.dataset import SplitDataset
from torchvision import datasets

import smdistributed.modelparallel.torch as smp


class GroupedNet(nn.Module):
    def __init__(self):
        super(GroupedNet, self).__init__()
        with smp.partition(0):
            # define child modules on device 0
        with smp.partition(1):
            # define child modules on device 1

    def forward(self, x):
        # define forward pass and return model outputs


# smdistributed: Define smp.step. Return any tensors needed outside.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)
        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()
        optimizer.step()


# smdistributed: initialize the backend
smp.init()

# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda")

# smdistributed: Download only on a single process per instance.
# When this is not present, the file is corrupted by multiple processes trying
# to download and extract at the same time
dataset = datasets.MNIST("../data", train=True, download=False)

# smdistributed: Shard the dataset based on data-parallel ranks
if smp.dp_size() > 1:
    partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
    dataset = SplitDataset(dataset, partitions=partitions_dict)
    dataset.select(f"{smp.dp_rank()}")

# smdistributed: Set drop_last=True to ensure that batch size is always divisible
# by the number of microbatches
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

model = GroupedNet()
optimizer = optim.Adadelta(model.parameters(), lr=4.0)

# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

train(model, device, train_loader, optimizer)
Considerations
When you configure a PyTorch training script using SageMaker's model parallelism library, you should be aware of the following:
- If you use an optimization technique that relies on global gradient norms (for example, the gradient norm computed over the entire model), such as some variants of the LAMB optimizer or global gradient clipping, you need to gather all the norms across the model partitions for correctness. You can use the library's basic communication data types to do this; see the sketch after this list.
- All torch.Tensor arguments to the forward methods of the nn.Modules in your model must be used in the computation of the module output. In other words, the library does not support the case where there is a torch.Tensor argument to a module on which the module output does not depend.
- The argument to the smp.DistributedModel.backward() call must depend on all model outputs. In other words, there cannot be an output from the smp.DistributedModel.forward call that is not used in the computation of the tensor that is fed into the smp.DistributedModel.backward call.
- If there are torch.cuda.synchronize() calls in your code, you might need to call torch.cuda.set_device(smp.local_rank()) immediately before the synchronize call. Otherwise, unnecessary CUDA contexts might be created on device 0, which needlessly consumes memory.
- Since the library places nn.Modules on different devices, the modules in the model must not depend on any global state that is modified inside smp.step. Any state that remains fixed throughout training, or that is modified outside smp.step in a way that is visible to all processes, is allowed.
- You don't need to move the model to the GPU (for example, using model.to(device)) when using the library. If you try to move the model to the GPU before the model is partitioned (before the first smp.step call), the move call is ignored. The library automatically moves the part of the model assigned to a rank to its GPU. Once training with the library starts, don't move the model to the CPU and use it, because it won't have correct parameters for modules that are not assigned to the partition held by the process. If you want to retrain a model or use it for inference without the library after it was trained using the model parallelism library, the recommended way is to save the full model using our checkpointing API and load it back into a regular PyTorch Module.
- If you have a list of modules such that the output of one feeds into another, replacing that list with nn.Sequential can significantly improve performance; see the example after this list.
- The weight update (optimizer.step()) needs to happen outside of smp.step, because that is when the entire backward pass is done and gradients are ready. When using a hybrid model with model and data parallelism, at this point, AllReduce of gradients is also guaranteed to finish.
- When using the library in combination with data parallelism, make sure that the number of batches on all data parallel ranks is the same so that AllReduce does not hang waiting for a rank which is not participating in the step.
- If you launch a training job using an ml.p4d instance type (such as ml.p4d.24xlarge), you must set the data loader variable num_workers=0. For example, you may define your DataLoader as follows:

dataloader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
    shuffle=shuffle,
)
- The inputs to smp.step must be the model inputs generated by DataLoader. This is because smp.step internally splits the input tensors along the batch dimension and pipelines them. This means that passing the DataLoader itself to the smp.step function to generate the model inputs inside does not work.

For example, if you define a DataLoader as follows:

train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

You should access the model inputs generated by train_loader and pass those to an smp.step decorated function. Do not pass train_loader directly to the smp.step function.

def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        ...
        _, loss_mb = train_step(model, data, target)
        ...

@smp.step
def train_step(model, data, target):
    ...
    return output, loss
- The input tensors to smp.step must be moved to the current device using the .to() API, which must take place after the torch.cuda.set_device(smp.local_rank()) call.

For example, you may define the train function as follows. This function moves data and target to the current device using the .to() API before using those input tensors to call train_step.

def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)
        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()
        optimizer.step()

The input tensors to this smp.step decorated function have been moved to the current device in the train function above. The model does not need to be moved to the current device. The library automatically moves the part of the model assigned to a rank to its GPU.

@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss
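The following is a minimal sketch of the global gradient norm consideration above. It assumes you pass in the parameters of the local model partition and a torch.distributed process group that spans the model-parallel ranks; the helper name global_grad_norm and the way you obtain those arguments are illustrative, not part of the library's API.

import torch
import torch.distributed as dist

def global_grad_norm(local_params, mp_group):
    # local_params: parameters owned by the local model partition.
    # mp_group: a torch.distributed process group spanning the model partitions.
    sq_norm = torch.zeros(1, device=torch.device("cuda", torch.cuda.current_device()))
    for param in local_params:
        if param.grad is not None:
            sq_norm += param.grad.detach().float().norm(2) ** 2
    # Sum the partial squared norms across all model partitions, then take the root.
    dist.all_reduce(sq_norm, op=dist.ReduceOp.SUM, group=mp_group)
    return sq_norm.sqrt()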
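To illustrate the nn.Sequential consideration above, the following plain PyTorch sketch (the layer sizes and class names are arbitrary) shows two modules that compute the same function; the second expresses the chain as a single nn.Sequential, which gives the partitioning algorithm a simpler structure to split.

import torch.nn as nn

# A chain of layers where the output of each feeds into the next ...
class ChainedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

# ... can instead be expressed with a single nn.Sequential module.
class SequentialNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])

    def forward(self, x):
        return self.blocks(x)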
Unsupported framework features
The following PyTorch features are unsupported by SageMaker's model parallelism library:
- If you use data parallelism with the native PyTorch DDP, the torch.nn.parallel.DistributedDataParallel wrapper module is not supported by the library. The library internally manages integrating with PyTorch DDP, including parameter broadcast and gradient AllReduce. When using the library, module buffers are only broadcast once at the start of training. If your model has module buffers that need to be synchronized across data parallel groups at each step, you can do so through the torch.distributed API, using the process group that can be obtained via smp.get_dp_process_group(); see the sketch after this list.
- For mixed precision training, the apex.amp module is not supported. The recommended way to use the library with automatic mixed-precision is to use torch.cuda.amp, with the exception of using smp.amp.GradScaler instead of the implementation in torch; see the sketch after this list.
- torch.jit.ScriptModules or ScriptFunctions are not supported by smp.DistributedModel.
- apex: FusedLayerNorm, FusedAdam, FusedLAMB, and FusedNovoGrad from apex are not supported. You can use the library implementations of these through the smp.optimizers and smp.nn APIs instead.
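For the module buffer point in the first item above, a minimal sketch of synchronizing buffers through the torch.distributed API might look like the following. The helper name sync_module_buffers and the way src_global_rank is chosen are illustrative; src_global_rank must be the global rank of the process you pick as the source within the data-parallel group.

import torch.distributed as dist
import smdistributed.modelparallel.torch as smp

def sync_module_buffers(module, src_global_rank):
    # Broadcast every module buffer (for example, batch-norm running statistics)
    # from src_global_rank to the other members of its data-parallel group.
    group = smp.get_dp_process_group()
    for buf in module.buffers():
        dist.broadcast(buf, src=src_global_rank, group=group)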
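For the mixed-precision item above, the following is a minimal sketch that combines torch.cuda.amp autocasting with the library's gradient scaler. It assumes smp.amp.GradScaler exposes the same scale/step/update interface as torch.cuda.amp.GradScaler; refer to FP16 Training with Model Parallelism for the authoritative workflow.

import torch
import torch.nn.functional as F
import smdistributed.modelparallel.torch as smp

scaler = smp.amp.GradScaler()

@smp.step
def train_step(model, data, target):
    # Run the forward pass under autocast and scale the loss before the backward pass.
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = F.nll_loss(output, target, reduction="mean")
    scaled_loss = scaler.scale(loss)
    model.backward(scaled_loss)
    return output, loss

def train(model, device, train_loader, optimizer):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        _, loss_mb = train_step(model, data, target)
        loss = loss_mb.reduce_mean()
        # Step the optimizer through the scaler and update the scale factor.
        scaler.step(optimizer)
        scaler.update()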