

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# 调整您的 PyTorch 训练脚本
<a name="debugger-modify-script-pytorch"></a>

要开始收集模型输出张量并调试训练问题，请对 PyTorch 训练脚本进行以下修改。

**注意**  
SageMaker 调试器无法从 [https://pytorch.org/docs/stable/nn.functional.html](https://pytorch.org/docs/stable/nn.functional.html)API 操作中收集模型输出张量。在编写 PyTorch 训练脚本时，建议改用这些[https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html)模块。

## 对于 PyTorch 1.12.0
<a name="debugger-modify-script-pytorch-1-12-0"></a>

如果您带上 PyTorch 训练脚本，则可以在训练脚本中使用几行额外的代码来运行训练作业并提取模型输出张量。你需要在`sagemaker-debugger`客户端库 APIs中使用[挂钩](https://sagemaker-debugger.readthedocs.io/en/website/hook-api.html)。仔细阅读以下说明，这些说明分别介绍了各个步骤并提供代码示例。

1. 创建钩子。

   **（推荐）用于 SageMaker AI 内部的训练作业**

   ```
   import smdebug.pytorch as smd
   hook=smd.get_hook(create_if_not_exists=True)
   ```

   当您在估算器中[使用 SageMaker Python SDK 使用调试器启动训练作业](debugger-configuration-for-debugging.md)使用任何 DebuggerHookConfig TensorBoardConfig、或规则启动训练作业时， SageMaker AI 会向您的训练实例添加一个 JSON 配置文件，该文件由该`get_hook`函数获取。请注意，如果您在估算器 APIs 中不包含任何配置，则不会有配置文件可供钩子查找，并且函数会返回。`None`

   **（可选）用于在 SageMaker AI 之外训练作业**

   如果您在本地模式下直接在 SageMaker 笔记本实例、Amazon EC2 实例或您自己的本地设备上运行训练作业，请使用`smd.Hook`类来创建挂钩。但是，这种方法只能存储张量集合并可用于可 TensorBoard 视化。 SageMaker 调试器的内置规则不适用于本地模式，因为这些规则要求 SageMaker AI ML 训练实例和 S3 实时存储来自远程实例的输出。在这种情况下，`smd.get_hook` API 会返回 `None`。

   如果您要创建手动钩子以在本地模式下保存张量，请使用以下带有逻辑的代码片段检查 `smd.get_hook` API 是否返回 `None`，并使用 `smd.Hook` 类创建手动钩子。请注意，您可以指定本地计算机中的任何输出目录。

   ```
   import smdebug.pytorch as smd
   hook=smd.get_hook(create_if_not_exists=True)
   
   if hook is None:
       hook=smd.Hook(
           out_dir='/path/to/your/local/output/',
           export_tensorboard=True
       )
   ```

1. 用钩子的类方法包装您的模型。

   `hook.register_module()` 方法获取您的模型并遍历每一层，寻找与您通过 [使用 SageMaker Python SDK 使用调试器启动训练作业](debugger-configuration-for-debugging.md) 中配置提供的正则表达式匹配的任何张量。通过这种钩子方法可以收集到的张量包括权重、偏差、激活、梯度、输入和输出。

   ```
   hook.register_module(model)
   ```
**提示**  
如果您从大型深度学习模型中完整地收集输出张量，则这些集合的总大小会呈指数级增长，并可能导致瓶颈。如果您要保存特定张量，还可以使用 `hook.save_tensor()` 方法。此方法可协助您为特定张量选取变量，并保存到您自己命名的自定义集合中。有关更多信息，请参阅本说明中的[步骤 7](#debugger-modify-script-pytorch-save-custom-tensor)。

1. 用钩子的类方法包装损失函数。

   `hook.register_loss` 方法用于包装损失函数。它会根据您在 [使用 SageMaker Python SDK 使用调试器启动训练作业](debugger-configuration-for-debugging.md) 的配置中设置的 `save_interval` 来提取损失值，并将它们保存到 `"losses"` 集合中。

   ```
   hook.register_loss(loss_function)
   ```

1. 在训练块中添加 `hook.set_mode(ModeKeys.TRAIN)`。这表示张量集合是在训练阶段中提取的。

   ```
   def train():
       ...
       hook.set_mode(ModeKeys.TRAIN)
   ```

1. 在验证块中添加 `hook.set_mode(ModeKeys.EVAL)`。这表示张量集合是在验证阶段中提取的。

   ```
   def validation():
       ...
       hook.set_mode(ModeKeys.EVAL)
   ```

1. 使用 [https://sagemaker-debugger.readthedocs.io/en/website/hook-constructor.html#smdebug.core.hook.BaseHook.save_scalar](https://sagemaker-debugger.readthedocs.io/en/website/hook-constructor.html#smdebug.core.hook.BaseHook.save_scalar) 保存自定义标量。您可以保存模型中没有的标量值。例如，如果您要记录评估期间计算的准确性值，请在计算准确性的行下方添加以下代码行。

   ```
   hook.save_scalar("accuracy", accuracy)
   ```

   请注意，您需要提供字符串作为第一个参数，用于命名自定义标量集合。这个名称将用于可视化中的标量值 TensorBoard，可以是任何你想要的字符串。

1. <a name="debugger-modify-script-pytorch-save-custom-tensor"></a>使用 [https://sagemaker-debugger.readthedocs.io/en/website/hook-constructor.html#smdebug.core.hook.BaseHook.save_tensor](https://sagemaker-debugger.readthedocs.io/en/website/hook-constructor.html#smdebug.core.hook.BaseHook.save_tensor) 保存自定义张量。与 [https://sagemaker-debugger.readthedocs.io/en/website/hook-constructor.html#smdebug.core.hook.BaseHook.save_scalar](https://sagemaker-debugger.readthedocs.io/en/website/hook-constructor.html#smdebug.core.hook.BaseHook.save_scalar) 类似，您可以保存其他张量，并定义自己的张量集合。例如，您可以提取传入模型的输入映像数据，并通过添加以下代码行，将其保存为自定义张量，其中 `"images"` 是自定义张量的示例名称，`image_inputs` 是输入映像数据的示例变量。

   ```
   hook.save_tensor("images", image_inputs)
   ```

   请注意，您必须向第一个参数提供字符串来命名自定义张量。`hook.save_tensor()` 使用第三个参数 `collections_to_write` 来指定张量集合，用于保存自定义张量。默认值为 `collections_to_write="default"`。如果您没有明确指定第三个参数，则自定义张量将保存到 `"default"` 张量集合中。

调整完训练脚本后，继续到 [使用 SageMaker Python SDK 使用调试器启动训练作业](debugger-configuration-for-debugging.md)。