Optimizer State Sharding - Amazon SageMaker AI

Optimizer State Sharding

Optimizer state sharding is a useful memory-saving technique that shards the optimizer state (the set of weights that describes the state of optimizer) across data parallel device groups. You can use optimizer state sharding whenever you use a stateful optimizer (such as Adam) or an FP16 optimizer (which stores both FP16 and FP32 copies of the parameters).

Note

Optimizer state sharding is available for PyTorch in the SageMaker model parallelism library v1.6.0 and later.

How to Use Optimizer State Sharding

You can turn on optimizer state sharding by setting "shard_optimizer_state": True in the modelparallel configuration.

When this feature is turned on, the library partitions the set of model parameters based on the data parallelism degree. The gradients corresponding to the ith partition get reduced only at the ith data parallel rank. At the end of the first call to an smp.step decorator function, the optimizer wrapped by smp.DistributedOptimizer redefines its parameters to be only limited to those parameters corresponding to the partition of the current data parallel rank. The redefined parameters are called virtual parameters and share underlying storage with the original parameters. During the first call to optimizer.step, the optimizer states are created based on these redefined parameters, which are sharded because of the original partition. After the optimizer update, the AllGather operation (as part of the optimizer.step call) runs across the data parallel ranks to achieve consistent parameter states.

Tip

Optimizer state sharding can be useful when the degree of data parallelism is greater than 1 and the model has more than a billion parameters.

The degree of data parallelism is calculated by (processes_per_host * instance_count / pipeline_parallel_degree), and the smp.dp_size() function handles the sizing in the background.

Configure a SageMaker PyTorch estimator

mpi_options = { "enabled" : True, "processes_per_host" : 8, # 8 processes "custom_mpi_options" : "--mca btl_vader_single_copy_mechanism none " } smp_options = { "enabled":True, "parameters": { "microbatches": 4, "pipeline_parallel_degree": 2, # alias for "partitions" "placement_strategy": "cluster", "tensor_parallel_degree": 2, # tp over 2 devices "ddp": True, "shard_optimizer_state": True } }

Adapt your PyTorch training script

See Adapt your PyTorch training script in the Tensor parallelism combined with pipeline parallelism section. There’s no additional modification required for the script.