

# Step 1: Modify Your Own Training Script Using SageMaker's Distributed Model Parallel Library
<a name="model-parallel-customize-training-script"></a>

Use this section to learn how to customize your training script to use the core features of the Amazon SageMaker AI model parallelism library. To use the library-specific API functions and parameters, we recommend you use this documentation alongside the [SageMaker model parallel library APIs](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smd_model_parallel.html) in the *SageMaker Python SDK documentation*.

The training script examples provided in these sections are simplified and designed to highlight the required changes you must make to use the library. For end-to-end, runnable notebook examples that demonstrate how to use a TensorFlow or PyTorch training script with the SageMaker model parallelism library, see [Amazon SageMaker AI model parallelism library v2 examples](distributed-model-parallel-v2-examples.md).

**Topics**
+ [Split the model of your training script using the SageMaker model parallelism library](#model-parallel-model-splitting-using-smp-lib)
+ [Modify a TensorFlow training script](model-parallel-customize-training-script-tf.md)
+ [Modify a PyTorch Training Script](model-parallel-customize-training-script-pt.md)

## Split the model of your training script using the SageMaker model parallelism library
<a name="model-parallel-model-splitting-using-smp-lib"></a>

There are two ways to modify your training script to set up model splitting: automated splitting or manual splitting.

### Automated model splitting
<a name="model-parallel-automated-model-splitting"></a>

When you use SageMaker's model parallelism library, you can take advantage of *automated model splitting*, also referred to as *automated model partitioning*. The library uses a partitioning algorithm that balances memory, minimizes communication between devices, and optimizes performance. You can configure the automated partitioning algorithm to optimize for speed or memory. 

Alternatively, you can use manual model splitting. We recommend automated model splitting, unless you are very familiar with the model architecture and have a good idea of how to efficiently partition your model.

#### How it works
<a name="model-parallel-automated-model-splitting-how-it-works"></a>

Auto-partitioning occurs during the first training step, when the `smp.step`-decorated function is first called. During this call, the library first constructs a version of the model in CPU RAM (to avoid GPU memory limitations), then analyzes the model graph and makes a partitioning decision. Based on this decision, each model partition is loaded onto a GPU, and only then is the first step executed. Because of these analysis and partitioning steps, the first training step might take longer. 

In either framework, the library manages the communication between devices through its own backend, which is optimized for AWS infrastructure.

The auto-partition design adapts to the characteristics of the framework, and the library does the partitioning at the granularity level that is most natural in each framework. For instance, in TensorFlow, each specific operation can be assigned to a different device, whereas in PyTorch, the assignment is done at the module level, where each module consists of multiple operations. The following sections review the specifics of the design in each framework.

##### Automated model splitting with PyTorch
<a name="model-parallel-auto-model-split-pt"></a>

During the first training step, the model parallelism library internally runs a tracing step that is meant to construct the model graph and determine the tensor and parameter shapes. After this tracing step, the library constructs a tree, which consists of the nested `nn.Module` objects in the model, as well as additional data gathered from tracing, such as the amount of stored `nn.Parameter` data and the execution time for each `nn.Module`. 

Next, the library traverses this tree from the root and runs a partitioning algorithm that assigns each `nn.Module` to a device, which balances computational load (measured by module execution time) and memory use (measured by the total stored `nn.Parameter` size and activations). If multiple `nn.Modules` share the same `nn.Parameter`, then these modules are placed on the same device to avoid maintaining multiple versions of the same parameter. Once the partitioning decision is made, the assigned modules and weights are loaded to their devices.
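The balancing idea behind this traversal can be illustrated with a simplified, hypothetical sketch. This is not the library's actual algorithm; it only shows how per-module cost (execution time plus parameter memory) and shared parameters can drive device assignment. All names and numbers below are made up for illustration.

```python
# Hypothetical sketch of the balancing idea behind auto-partitioning.
# NOT the library's actual algorithm; it only illustrates how per-module
# cost and shared parameters can drive device assignment.

def partition(modules, num_devices):
    """Greedily assign each module to the least-loaded device.

    modules: list of (name, cost, shared_param_group) tuples, where cost
             combines execution time and parameter memory, and modules
             with the same shared_param_group must be colocated.
    """
    loads = [0.0] * num_devices
    assignment = {}
    group_device = {}  # modules sharing a parameter go to one device

    # Assign heaviest modules first for a better balance.
    for name, cost, group in sorted(modules, key=lambda m: -m[1]):
        if group is not None and group in group_device:
            device = group_device[group]      # colocate shared parameters
        else:
            device = loads.index(min(loads))  # least-loaded device
        assignment[name] = device
        loads[device] += cost
        if group is not None:
            group_device[group] = device
    return assignment, loads

modules = [
    ("embedding", 4.0, "emb"),  # shares weights with "lm_head"
    ("block0", 3.0, None),
    ("block1", 3.0, None),
    ("lm_head", 1.0, "emb"),    # tied to "embedding"
]
assignment, loads = partition(modules, num_devices=2)
print(assignment)
```

Note how `embedding` and `lm_head` land on the same device because they share a parameter group, exactly the colocation behavior described above.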

For instructions on how to register the `smp.step` decorator to your PyTorch training script, see [Automated splitting with PyTorch](model-parallel-customize-training-script-pt.md#model-parallel-customize-training-script-pt-16).

##### Automated model splitting with TensorFlow
<a name="model-parallel-auto-model-split-tf"></a>

The model parallelism library analyzes the sizes of the trainable variables and the graph structure, and internally uses a graph partitioning algorithm. This algorithm comes up with a device assignment for each operation, with the objective of minimizing the amount of communication needed across devices, subject to two constraints: 
+ Balancing the number of variables stored in each device
+ Balancing the number of operations executed in each device

If you specify `speed` for `optimize` (in the model parallelism parameters in the Python SDK), the library tries to balance the number of operations and `tf.Variable` objects in each device. Otherwise, it tries to balance the total size of `tf.Variables`.
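For context, the `optimize` option is set through the `modelparallel` parameters of the SageMaker Python SDK estimator's `distribution` argument. The following is a minimal sketch of that configuration; the parameter values are illustrative, not recommendations.

```python
# Sketch of the modelparallel "parameters" dict passed to the SageMaker
# Python SDK estimator's distribution argument; values are illustrative.
smp_parameters = {
    "partitions": 2,      # number of model partitions
    "microbatches": 4,    # pipeline microbatches per batch
    "optimize": "speed",  # "speed" balances operation counts;
                          # "memory" balances total variable size
}

distribution = {
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": smp_parameters,
        }
    }
}
```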

Once the partitioning decision is made, the library creates a serialized representation of the subgraph that each device needs to execute and imports them onto each device. While partitioning, the library places operations that consume the same `tf.Variable` and operations that are part of the same Keras layer onto the same device. It also respects the colocation constraints imposed by TensorFlow. This means that, for example, if there are two Keras layers that share a `tf.Variable`, then all operations that are part of these layers are placed on a single device.

For instructions on how to register the `smp.step` decorator to your TensorFlow training script, see [Automated splitting with TensorFlow](model-parallel-customize-training-script-tf.md#model-parallel-customize-training-script-tf-23).

##### Comparison of automated model splitting between frameworks
<a name="model-parallel-auto-model-split-comparison"></a>

In TensorFlow, the fundamental unit of computation is a `tf.Operation`, and TensorFlow represents the model as a directed acyclic graph (DAG) of `tf.Operation`s. The model parallelism library therefore partitions this DAG so that each node goes to one device. Crucially, `tf.Operation` objects are sufficiently rich with customizable attributes, and they are universal in the sense that every model is guaranteed to consist of a graph of such objects. 

PyTorch, on the other hand, does not have an equivalent notion of operation that is sufficiently rich and universal. The closest unit of computation in PyTorch that has these characteristics is an `nn.Module`, which is at a much higher granularity level, and this is why the library does partitioning at this level in PyTorch.

### Manual model splitting
<a name="model-parallel-manual-model-splitting"></a>

If you want to manually specify how to partition your model across devices, use the `smp.partition` context manager. For instructions on how to set the context manager for manual partitioning, see the following pages.
+ [Manual splitting with TensorFlow](model-parallel-customize-training-script-tf.md#model-parallel-customize-training-script-tf-manual)
+ [Manual splitting with PyTorch](model-parallel-customize-training-script-pt.md#model-parallel-customize-training-script-pt-16-hvd)

To use this option, in Step 2 you need to set `auto_partition` to `False` and define a `default_partition` in the framework estimator class of the SageMaker Python SDK. Any operation that is not explicitly placed on a partition through the `smp.partition` context manager is executed on the `default_partition`. In this case, the automated splitting logic is bypassed, and each operation is placed based on your specification. Based on the resulting graph structure, the model parallelism library automatically creates a pipelined execution schedule.
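As a minimal sketch, the manual-splitting settings described above might look like the following in the `modelparallel` parameters of the estimator. The values are illustrative.

```python
# Sketch: disabling auto-partitioning for manual splitting; the
# parameter values here are illustrative, not recommendations.
smp_parameters = {
    "partitions": 2,
    "microbatches": 4,
    "auto_partition": False,  # bypass the automated splitting logic
    "default_partition": 0,   # operations not explicitly placed go here
}
```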

# Modify a TensorFlow training script
<a name="model-parallel-customize-training-script-tf"></a>

In this section, you learn how to modify TensorFlow training scripts to configure the SageMaker model parallelism library for auto-partitioning and manual partitioning. This selection of examples also includes an example integrated with Horovod for hybrid model and data parallelism.

**Note**  
To find which TensorFlow versions are supported by the library, see [Supported Frameworks and AWS Regions](distributed-model-parallel-support.md).

The required modifications you must make to your training script to use the library are listed in [Automated splitting with TensorFlow](#model-parallel-customize-training-script-tf-23).

To learn how to modify your training script to use hybrid model and data parallelism with Horovod, see [Automated splitting with TensorFlow and Horovod for hybrid model and data parallelism](#model-parallel-customize-training-script-tf-2.3).

If you want to use manual partitioning, also review [Manual splitting with TensorFlow](#model-parallel-customize-training-script-tf-manual). 

The following topics show examples of training scripts that you can use to configure SageMaker's model parallelism library for auto-partitioning and manual partitioning TensorFlow models. 

**Note**  
Auto-partitioning is enabled by default. Unless otherwise specified, the example scripts use auto-partitioning.

**Topics**
+ [Automated splitting with TensorFlow](#model-parallel-customize-training-script-tf-23)
+ [Automated splitting with TensorFlow and Horovod for hybrid model and data parallelism](#model-parallel-customize-training-script-tf-2.3)
+ [Manual splitting with TensorFlow](#model-parallel-customize-training-script-tf-manual)
+ [Unsupported framework features](#model-parallel-tf-unsupported-features)

## Automated splitting with TensorFlow
<a name="model-parallel-customize-training-script-tf-23"></a>

The following training script changes are required to run a TensorFlow model with SageMaker's model parallelism library:

1. Import and initialize the library with [`smp.init`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#smp.init).

1. Define a Keras model by inheriting from [`smp.DistributedModel`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_tensorflow.html) instead of the Keras `Model` class. Return the model outputs from the `call` method of the `smp.DistributedModel` object. Be mindful that any tensors returned from the `call` method are broadcast across model-parallel devices, incurring communication overhead, so any tensors that are not needed outside the `call` method (such as intermediate activations) should not be returned.

1. Set `drop_remainder=True` in the `tf.data.Dataset.batch()` method. This ensures that the batch size is always divisible by the number of microbatches.

1. Seed the random operations in the data pipeline using `smp.dp_rank()`, e.g., `shuffle(ds, seed=smp.dp_rank())` to ensure consistency of data samples across GPUs that hold different model partitions.

1. Put the forward and backward logic in a step function and decorate it with `smp.step`.

1. Perform post-processing on the outputs across microbatches using [`StepOutput`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#StepOutput) methods such as `reduce_mean`. The `smp.step`-decorated function must have a return value that depends on the output of `smp.DistributedModel`.

1. If there is an evaluation step, similarly place the forward logic inside an `smp.step`-decorated function and post-process the outputs using [`StepOutput` API](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#StepOutput).

To learn more about the SageMaker model parallelism library API, refer to the [API documentation](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smd_model_parallel.html). 

The following Python script is an example of a training script after the changes are made.

```
import tensorflow as tf

# smdistributed: Import TF2.x API
import smdistributed.modelparallel.tensorflow as smp

# smdistributed: Initialize
smp.init()

# Download and load MNIST dataset.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    "MNIST-data-%d" % smp.rank()
)
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# smdistributed: If needed, seed the shuffle with smp.dp_rank(), and drop_remainder
# in batching to make sure batch size is always divisible by number of microbatches
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000, seed=smp.dp_rank())
    .batch(256, drop_remainder=True)
)

# smdistributed: Define smp.DistributedModel the same way as Keras sub-classing API 
class MyModel(smp.DistributedModel):
    def __init__(self):
        super(MyModel, self).__init__()
        # define layers

    def call(self, x, training=None):
        # define forward pass and return the model output

model = MyModel()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# smdistributed: Define smp.step. Return any tensors needed outside
@smp.step
def get_grads(images, labels):
    predictions = model(images, training=True)
    loss = loss_object(labels, predictions)

    grads = optimizer.get_gradients(loss, model.trainable_variables)
    return grads, loss, predictions


@tf.function
def train_step(images, labels):
    gradients, loss, predictions = get_grads(images, labels)

    # smdistributed: Accumulate the gradients across microbatches
    gradients = [g.accumulate() for g in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # smdistributed: Merge predictions and average losses across microbatches
    train_accuracy(labels, predictions.merge())
    return loss.reduce_mean()


for epoch in range(5):
    # Reset the metrics at the start of the next epoch
    train_accuracy.reset_states()
    for images, labels in train_ds:
        loss = train_step(images, labels)
    accuracy = train_accuracy.result()
```

If you are done preparing your training script, proceed to [Step 2: Launch a Training Job Using the SageMaker Python SDK](model-parallel-sm-sdk.md). If you want to run a hybrid model and data parallel training job, continue to the next section.

## Automated splitting with TensorFlow and Horovod for hybrid model and data parallelism
<a name="model-parallel-customize-training-script-tf-2.3"></a>

You can use the SageMaker model parallelism library with Horovod for hybrid model and data parallelism. To read more about how the library splits a model for hybrid parallelism, see [Pipeline parallelism (available for PyTorch and TensorFlow)](model-parallel-intro.md#model-parallel-intro-pp).

In this step, we focus on how to modify your training script to adapt the SageMaker model parallelism library.

To properly set up your training script to pick up the hybrid parallelism configuration that you'll set in [Step 2: Launch a Training Job Using the SageMaker Python SDK](model-parallel-sm-sdk.md), use the library's helper functions, `smp.dp_rank()` and `smp.mp_rank()`, which automatically detect the data parallel rank and model parallel rank respectively. 
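To build intuition for how these two ranks relate, consider the following hypothetical sketch. The actual mapping from a process's global rank to its (`mp_rank`, `dp_rank`) pair depends on the library's placement strategy, so the layout below is only one possible assignment, shown for illustration.

```python
# Hypothetical rank layout for hybrid parallelism: 8 processes split
# into 2 model partitions (pipeline) x 4 data-parallel replicas. The
# real mapping depends on the library's placement strategy; this only
# illustrates how the two ranks index the process grid.

NUM_PARTITIONS = 2  # model-parallel degree (what smp.mp_size() reports)
DP_SIZE = 4         # data-parallel degree (what smp.dp_size() reports)

def ranks(global_rank):
    mp_rank = global_rank % NUM_PARTITIONS   # which model partition
    dp_rank = global_rank // NUM_PARTITIONS  # which data-parallel replica
    return mp_rank, dp_rank

for r in range(NUM_PARTITIONS * DP_SIZE):
    mp, dp = ranks(r)
    print(f"global rank {r}: mp_rank={mp}, dp_rank={dp}")
```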

To find all MPI primitives the library supports, see [MPI Basics](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#mpi-basics) in the SageMaker Python SDK documentation. 

The changes required in the script are:
+ Adding `hvd.allreduce` calls to average the accumulated gradients across data-parallel ranks
+ Broadcasting variables after the first batch, as required by Horovod
+ Seeding shuffling and/or sharding operations in the data pipeline with `smp.dp_rank()`

**Note**  
When you use Horovod, you must not directly call `hvd.init` in your training script. Instead, you'll have to set `"horovod"` to `True` in the SageMaker Python SDK `modelparallel` parameters in [Step 2: Launch a Training Job Using the SageMaker Python SDK](model-parallel-sm-sdk.md). This allows the library to internally initialize Horovod based on the device assignments of model partitions. Calling `hvd.init()` directly in your training script can cause problems.
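Setting `"horovod"` to `True` happens in the estimator's `distribution` argument, alongside the MPI options. The following is a minimal sketch; the parameter values are illustrative.

```python
# Sketch: enabling Horovod through the modelparallel parameters instead
# of calling hvd.init() in the training script; values are illustrative.
distribution = {
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": {
                "partitions": 2,
                "microbatches": 4,
                "horovod": True,  # the library initializes Horovod internally,
                                  # aligned with the partition device assignment
            },
        }
    },
    "mpi": {
        "enabled": True,
        "processes_per_host": 2,  # illustrative value
    },
}
```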

**Note**  
Using the `hvd.DistributedOptimizer` API directly in your training script might result in poor training performance and speed, because the API implicitly places the `AllReduce` operation inside `smp.step`. We recommend using the model parallelism library with Horovod by directly calling `hvd.allreduce` after calling `accumulate()` or `reduce_mean()` on the gradients returned from `smp.step`, as shown in the following example.

To learn more about the SageMaker model parallelism library API, refer to the [API documentation](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smd_model_parallel.html).

```
import tensorflow as tf
import horovod.tensorflow as hvd

# smdistributed: Import TF2.x API 
import smdistributed.modelparallel.tensorflow as smp

# smdistributed: Initialize
smp.init()

# Download and load MNIST dataset.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    "MNIST-data-%d" % smp.rank()
)
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# smdistributed: Seed the shuffle with smp.dp_rank(), and drop_remainder
# in batching to make sure batch size is always divisible by number of microbatches
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000, seed=smp.dp_rank())
    .batch(256, drop_remainder=True)
)

# smdistributed: Define smp.DistributedModel the same way as Keras sub-classing API 
class MyModel(smp.DistributedModel):
    def __init__(self):
        super(MyModel, self).__init__()
        # define layers

    def call(self, x, training=None):
        # define forward pass and return model outputs


model = MyModel()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# smdistributed: Define smp.step. Return any tensors needed outside
@smp.step
def get_grads(images, labels):
    predictions = model(images, training=True)
    loss = loss_object(labels, predictions)

    grads = optimizer.get_gradients(loss, model.trainable_variables)
    return grads, loss, predictions


@tf.function
def train_step(images, labels, first_batch):
    gradients, loss, predictions = get_grads(images, labels)

    # smdistributed: Accumulate the gradients across microbatches
    # Horovod: AllReduce the accumulated gradients
    gradients = [hvd.allreduce(g.accumulate()) for g in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Horovod: Broadcast the variables after first batch 
    if first_batch:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(optimizer.variables(), root_rank=0)

    # smdistributed: Merge predictions across microbatches
    train_accuracy(labels, predictions.merge())
    return loss.reduce_mean()


for epoch in range(5):
    # Reset the metrics at the start of the next epoch
    train_accuracy.reset_states()

    for batch, (images, labels) in enumerate(train_ds):
        loss = train_step(images, labels, tf.constant(batch == 0))
```

## Manual splitting with TensorFlow
<a name="model-parallel-customize-training-script-tf-manual"></a>

Use `smp.partition` context managers to place operations in specific partitions. Any operation not placed in an `smp.partition` context is placed in the `default_partition`. To learn more about the SageMaker model parallelism library API, refer to the [API documentation](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smd_model_parallel.html). 

```
import tensorflow as tf

# smdistributed: Import TF2.x API.
import smdistributed.modelparallel.tensorflow as smp

# smdistributed: Initialize
smp.init()

# Download and load MNIST dataset.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    "MNIST-data-%d" % smp.rank()
)
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# smdistributed: If needed, seed the shuffle with smp.dp_rank(), and drop_remainder
# in batching to make sure batch size is always divisible by number of microbatches.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000, seed=smp.dp_rank())
    .batch(256, drop_remainder=True)
)

# smdistributed: Define smp.DistributedModel the same way as Keras sub-classing API.
class MyModel(smp.DistributedModel):
    def __init__(self):
        super(MyModel, self).__init__()
        # define layers

    def call(self, x):
        with smp.partition(0):
            x = self.layer0(x)
        with smp.partition(1):
            return self.layer1(x)


model = MyModel()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# smdistributed: Define smp.step. Return any tensors needed outside
@smp.step
def get_grads(images, labels):
    predictions = model(images, training=True)
    loss = loss_object(labels, predictions)

    grads = optimizer.get_gradients(loss, model.trainable_variables)
    return grads, loss, predictions


@tf.function
def train_step(images, labels):
    gradients, loss, predictions = get_grads(images, labels)

    # smdistributed: Accumulate the gradients across microbatches
    gradients = [g.accumulate() for g in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # smdistributed: Merge predictions and average losses across microbatches
    train_accuracy(labels, predictions.merge())
    return loss.reduce_mean()


for epoch in range(5):
    # Reset the metrics at the start of the next epoch
    train_accuracy.reset_states()
    for images, labels in train_ds:
        loss = train_step(images, labels)
    accuracy = train_accuracy.result()
```

## Unsupported framework features
<a name="model-parallel-tf-unsupported-features"></a>

The following TensorFlow features are not supported by the library:
+ `tf.GradientTape()` is currently not supported. You can use `Optimizer.get_gradients()` or `Optimizer.compute_gradients()` instead to compute gradients.
+ The `tf.train.Checkpoint.restore()` API is currently not supported. For checkpointing, use `smp.CheckpointManager` instead, which provides the same API and functionality. Note that checkpoint restores with `smp.CheckpointManager` should take place after the first step.

# Modify a PyTorch Training Script
<a name="model-parallel-customize-training-script-pt"></a>

In this section, you learn how to modify PyTorch training scripts to configure the SageMaker model parallelism library for auto-partitioning and manual partitioning.

**Note**  
To find which PyTorch versions are supported by the library, see [Supported Frameworks and AWS Regions](distributed-model-parallel-support.md).

**Tip**  
For end-to-end notebook examples that demonstrate how to use a PyTorch training script with the SageMaker model parallelism library, see [Amazon SageMaker AI model parallelism library v1 examples](distributed-model-parallel-examples.md).

Note that auto-partitioning is enabled by default. Unless otherwise specified, the following scripts use auto-partitioning. 

**Topics**
+ [Automated splitting with PyTorch](#model-parallel-customize-training-script-pt-16)
+ [Manual splitting with PyTorch](#model-parallel-customize-training-script-pt-16-hvd)
+ [Considerations](#model-parallel-pt-considerations)
+ [Unsupported framework features](#model-parallel-pt-unsupported-features)

## Automated splitting with PyTorch
<a name="model-parallel-customize-training-script-pt-16"></a>

The following training script changes are required to run a PyTorch training script with SageMaker's model parallelism library:

1. Import and initialize the library with [`smp.init`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#smp.init).

1. Wrap the model with [`smp.DistributedModel`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_pytorch.html#smp.DistributedModel). Be mindful that any tensors returned from the `forward` method of the underlying `nn.Module` object are broadcast across model-parallel devices, incurring communication overhead, so any tensors that are not needed outside the `forward` method (such as intermediate activations) should not be returned.
**Note**  
For FP16 training, you need to use the [smdistributed.modelparallel.torch.model\_creation()](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/latest/smd_model_parallel_pytorch.html) context manager to wrap the model. For more information, see [FP16 Training with Model Parallelism](model-parallel-extended-features-pytorch-fp16.md).

1. Wrap the optimizer with [`smp.DistributedOptimizer`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_pytorch.html#smp.DistributedOptimizer).
**Note**  
For FP16 training, you need to set up static or dynamic loss scaling. For more information, see [FP16 Training with Model Parallelism](model-parallel-extended-features-pytorch-fp16.md).

1. Use the returned `DistributedModel` object instead of a user model.

1. Put the forward and backward logic in a step function and decorate it with [`smp.step`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html).

1. Restrict each process to its own device through `torch.cuda.set_device(smp.local_rank())`.

1. Move the input tensors to the GPU using the `.to()` API before the `smp.step` call (see example below).

1. Replace `torch.Tensor.backward` and `torch.autograd.backward` with `DistributedModel.backward`.

1. Perform post-processing on the outputs across microbatches using [`StepOutput`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#StepOutput) methods such as `reduce_mean`.

1. If there is an evaluation step, similarly place the forward logic inside an `smp.step`-decorated function and post-process the outputs using [`StepOutput` API](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_common_api.html#StepOutput).

1. Set `drop_last=True` in `DataLoader`. Alternatively, manually skip a batch in the training loop if the batch size is not divisible by the number of microbatches.

To learn more about the SageMaker model parallelism library API, refer to the [API documentation](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smd_model_parallel.html). 

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchnet.dataset import SplitDataset
from torchvision import datasets

import smdistributed.modelparallel.torch as smp

class GroupedNet(nn.Module):
    def __init__(self):
        super(GroupedNet, self).__init__()
        # define layers

    def forward(self, x):
        # define forward pass and return model outputs


# smdistributed: Define smp.step. Return any tensors needed outside.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)

        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()

        optimizer.step()

# smdistributed: initialize the backend
smp.init()

# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda")

# smdistributed: Download should be done by only a single process per instance.
# Otherwise, the file is corrupted by multiple processes trying to download
# and extract at the same time. Here the dataset is assumed to be pre-downloaded.
dataset = datasets.MNIST("../data", train=True, download=False)

# smdistributed: Shard the dataset based on data-parallel ranks
if smp.dp_size() > 1:
    partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
    dataset = SplitDataset(dataset, partitions=partitions_dict)
    dataset.select(f"{smp.dp_rank()}")

# smdistributed: Set drop_last=True to ensure that batch size is always divisible
# by the number of microbatches
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

model = GroupedNet()
optimizer = optim.Adadelta(model.parameters(), lr=4.0)

# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

train(model, device, train_loader, optimizer)
```

## Manual splitting with PyTorch
<a name="model-parallel-customize-training-script-pt-16-hvd"></a>

Use [`smp.partition`](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/v1.2.0/smd_model_parallel_pytorch.html) context managers to place modules on specific devices. Any module not placed in an `smp.partition` context is placed in the `default_partition`. The `default_partition` needs to be provided if `auto_partition` is set to `False`. The modules that are created within a specific `smp.partition` context are placed on the corresponding partition.

To learn more about the SageMaker model parallelism library API, refer to the [API documentation](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smd_model_parallel.html). 
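
For reference, manual splitting is enabled through the model parallel parameters that you pass to the SageMaker Python SDK estimator's `distribution` argument. The following is a hedged sketch; the parameter values are illustrative, and you should confirm the exact configuration shape against the API documentation linked above.

```
# Illustrative model-parallel configuration for manual splitting. Pass this
# dict under the "smdistributed" -> "modelparallel" entry of the estimator's
# distribution argument.
smp_options = {
    "enabled": True,
    "parameters": {
        "partitions": 2,          # number of model partitions
        "auto_partition": False,  # disable automated splitting
        "default_partition": 0,   # partition for modules outside smp.partition
    },
}
```

With `auto_partition` set to `False`, any module not wrapped in an `smp.partition` context lands on `default_partition`.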

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchnet.dataset import SplitDataset
from torchvision import datasets

import smdistributed.modelparallel.torch as smp

class GroupedNet(nn.Module):
    def __init__(self):
        super(GroupedNet, self).__init__()
        with smp.partition(0):
            # Example child modules for MNIST; placed on partition 0
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
        with smp.partition(1):
            # Example child modules; placed on partition 1
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


# smdistributed: Define smp.step. Return any tensors needed outside.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)

        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()

        optimizer.step()

# smdistributed: initialize the backend
smp.init()

# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda")

# smdistributed: Download only on a single process per instance.
# When this is not present, the file is corrupted by multiple processes trying
# to download and extract at the same time
if smp.local_rank() == 0:
    dataset = datasets.MNIST("../data", train=True, download=True)
smp.barrier()
dataset = datasets.MNIST("../data", train=True, download=False)

# smdistributed: Shard the dataset based on data-parallel ranks
if smp.dp_size() > 1:
    partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
    dataset = SplitDataset(dataset, partitions=partitions_dict)
    dataset.select(f"{smp.dp_rank()}")

# smdistributed: Set drop_last=True to ensure that batch size is always divisible
# by the number of microbatches
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

model = GroupedNet()
optimizer = optim.Adadelta(model.parameters(), lr=4.0)

# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

train(model, device, train_loader, optimizer)
```
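
In the training loop above, `train_step` returns `StepOutput` objects that hold one value per microbatch, and `reduce_mean()` averages them into a single tensor. A pure-Python stand-in (not the library's `StepOutput` class) for that reduction:

```
class FakeStepOutput:
    # Minimal stand-in for smp.StepOutput: holds one loss per microbatch.
    def __init__(self, outputs):
        self.outputs = outputs

    def reduce_mean(self):
        # Average across microbatches, as loss_mb.reduce_mean() does above
        return sum(self.outputs) / len(self.outputs)

# Four microbatch losses produced by one smp.step call (illustrative values)
loss_mb = FakeStepOutput([0.9, 1.1, 1.0, 1.2])
loss = loss_mb.reduce_mean()  # average of the four microbatch losses
```

This is why `drop_last=True` matters: every batch must divide evenly into microbatches for the per-microbatch outputs to be well defined.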

## Considerations
<a name="model-parallel-pt-considerations"></a>

When you configure a PyTorch training script using SageMaker's model parallelism library, you should be aware of the following:
+ If you are using an optimization technique that relies on global gradient norms (for example, a gradient norm computed over the entire model, as in some variants of the LAMB optimizer or in global gradient clipping), you must gather the norms across all model partitions for correctness. You can use the library's basic communication primitives to do this.
+ All `torch.Tensor` arguments to the forward methods of the `nn.Modules` in your model must be used in the computation of the module output. In other words, the library does not support the case where there is a `torch.Tensor` argument to a module on which the module output does not depend.
+ The argument to the `smp.DistributedModel.backward()` call must depend on all model outputs. In other words, there cannot be an output from the `smp.DistributedModel.forward` call that is not used in the computation of the tensor that is fed into the `smp.DistributedModel.backward` call.
+ If there are `torch.cuda.synchronize()` calls in your code, you might need to call `torch.cuda.set_device(smp.local_rank())` immediately before the synchronize call. Otherwise, unnecessary CUDA contexts might be created on device 0, needlessly consuming memory.
+ Since the library places `nn.Modules` on different devices, the modules in the model must not depend on any global state that is modified inside `smp.step`. Any state that remains fixed throughout training, or that is modified outside `smp.step` in a way that is visible to all processes, is allowed.
+ You don’t need to move the model to GPU (for example, using `model.to(device)`) when using the library. If you try to move the model to GPU before the model is partitioned (before the first `smp.step` call), the move call is ignored. The library automatically moves the part of the model assigned to a rank to its GPU. Once training with the library starts, don’t move the model to CPU and use it, as it won’t have correct parameters for modules not assigned to the partition held by the process. If you want to retrain a model or use it for inference without the library after it was trained using the model parallelism library, the recommended way is to save the full model using our checkpointing API and load it back to a regular PyTorch Module.
+ If you have a list of modules such that output of one feeds into another, replacing that list with `nn.Sequential` can significantly improve performance.
+ The weight update (`optimizer.step()`) needs to happen outside of `smp.step` because that is when the entire backward pass is done and gradients are ready. When using a hybrid model with model and data parallelism, at this point, AllReduce of gradients is also guaranteed to finish.
+ When using the library in combination with data parallelism, make sure that the number of batches on all data parallel ranks is the same so that AllReduce does not hang waiting for a rank which is not participating in the step.
+ If you launch a training job using an ml.p4d instance type (such as ml.p4d.24xlarge), you must set the `DataLoader` parameter `num_workers=0`. For example, you can define your `DataLoader` as follows:

  ```
  dataloader = torch.utils.data.DataLoader(
      data,
      batch_size=batch_size,
      num_workers=0,
      pin_memory=True,
      drop_last=True,
      shuffle=shuffle,
  )
  ```
+ The inputs to `smp.step` must be the model inputs generated by `DataLoader`. This is because `smp.step` internally splits the input tensors along the batch dimension and pipelines them. This means that passing `DataLoader` itself to the `smp.step` function to generate the model inputs inside does not work. 

  For example, if you define a `DataLoader` as follows:

  ```
  train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)
  ```

  You should access the model inputs generated by `train_loader` and pass those to an `smp.step` decorated function. Do not pass `train_loader` directly to the `smp.step` function.

  ```
  def train(model, device, train_loader, optimizer):
      model.train()
      for batch_idx, (data, target) in enumerate(train_loader):
          ...
          _, loss_mb = train_step(model, data, target)
          ...
  
  @smp.step
  def train_step(model, data, target):
      ...
      return output, loss
  ```
+ The input tensors to `smp.step` must be moved to the current device using the `.to()` API, and this must take place after the `torch.cuda.set_device(smp.local_rank())` call.

  For example, you can define the `train` function as follows. This function moves `data` and `target` to the current device using the `.to()` API before using those input tensors to call `train_step`.

  ```
  def train(model, device, train_loader, optimizer):
      model.train()
      for batch_idx, (data, target) in enumerate(train_loader):
          # smdistributed: Move input tensors to the GPU ID used by the current process,
          # based on the set_device call.
          data, target = data.to(device), target.to(device)
          optimizer.zero_grad()
          # Return value, loss_mb is a StepOutput object
          _, loss_mb = train_step(model, data, target)
  
          # smdistributed: Average the loss across microbatches.
          loss = loss_mb.reduce_mean()
  
          optimizer.step()
  ```

  The input tensors to this `smp.step` decorated function have been moved to the current device in the `train` function above. The model does *not* need to be moved to the current device. The library automatically moves the part of the model assigned to a rank to its GPU.

  ```
  @smp.step
  def train_step(model, data, target):
      output = model(data)
      loss = F.nll_loss(output, target, reduction="mean")
      model.backward(loss)
      return output, loss
  ```
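
The first consideration above (global gradient norms) comes down to combining per-partition norms correctly: the squared norms must be summed across partitions before taking the square root. A pure-Python illustration follows; in a real script, the cross-partition sum would use the library's communication primitives rather than a local list, and the norm values here are hypothetical.

```
import math

# Hypothetical L2 gradient norms computed independently on two model partitions
partition_norms = [3.0, 4.0]

# Neither the max nor the mean of per-partition norms is the global norm.
# Instead, sum the squared norms across partitions, then take the square root.
global_norm = math.sqrt(sum(n ** 2 for n in partition_norms))
print(global_norm)  # 5.0
```

Techniques such as global gradient clipping should use `global_norm`, not any single partition's local norm, when deciding whether to scale gradients.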

## Unsupported framework features
<a name="model-parallel-pt-unsupported-features"></a>

The following PyTorch features are unsupported by SageMaker's model parallelism library:
+ If you use data parallelism with the native [PyTorch DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), the [https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) wrapper module is not supported by the library. The library internally manages integrating with PyTorch DDP, including parameter broadcast and gradient AllReduce. When using the library, module buffers are only broadcast once at the start of training. If your model has module buffers that need to be synchronized across data parallel groups at each step, you can do so through the `torch.distributed` API, using the process group that can be obtained via `smp.get_dp_process_group()`.
+ For mixed precision training, the `apex.amp` module is not supported. The recommended way to use the library with automatic mixed precision is to use `torch.cuda.amp`, with the exception of using `smp.amp.GradScaler` in place of the implementation in `torch`.
+ `torch.jit.ScriptModules` or `ScriptFunctions` are not supported by `smp.DistributedModel`.
+ `apex`: `FusedLayerNorm`, `FusedAdam`, `FusedLAMB`, and `FusedNovoGrad` from `apex` are not supported. You can instead use the library's implementations of these through the `smp.optimizers` and `smp.nn` APIs.