启用检查点功能 - Amazon SageMaker

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

启用检查点功能

启用检查点功能后, SageMaker 将检查点保存到 Amazon S3,并将您的训练作业与检查点 S3 存储桶同步。您可以将 S3 通用存储桶或 S3 目录存储桶用于检查点 S3 存储桶。

训练期间写入检查点的架构图。

以下示例说明如何在构造 SageMaker 估算器时配置检查点路径。要启用检查点,将 checkpoint_s3_uricheckpoint_local_path 参数添加到估算器。

以下示例模板展示了如何创建通用 SageMaker 估算器并启用检查点功能。通过指定 image_uri 参数,可以将此模板用于支持的算法。要查找支持检查点URIs的算法的 Docker 镜像 SageMaker,请参阅 Docker 注册表路径和示例代码。您也可以Estimator用其他 SageMaker 框架的估计器父类和估计器类替换estimator和,例如、、和。TensorFlow PyTorch MXNet HuggingFace XGBoost

import sagemaker from sagemaker.estimator import Estimator bucket=sagemaker.Session().default_bucket() base_job_name="sagemaker-checkpoint-test" checkpoint_in_bucket="checkpoints" # The S3 URI to store the checkpoints checkpoint_s3_bucket="s3://{}/{}/{}".format(bucket, base_job_name, checkpoint_in_bucket) # The local path where the model will save its checkpoints in the training container checkpoint_local_path="/opt/ml/checkpoints" estimator = Estimator( ... image_uri="<ecr_path>/<algorithm-name>:<tag>" # Specify to use built-in algorithms output_path=bucket, base_job_name=base_job_name, # Parameters required to enable checkpointing checkpoint_s3_uri=checkpoint_s3_bucket, checkpoint_local_path=checkpoint_local_path )

以下两个参数指定检查点的路径:

  • checkpoint_local_path – 指定模型定期在训练容器中保存检查点的本地路径。默认路径设置为 '/opt/ml/checkpoints'。如果您使用的是其他框架或自带训练容器,请确保训练脚本的检查点配置指定 '/opt/ml/checkpoints' 路径。

    注意

    我们建议指定与默认 SageMaker 检查点设置一致的本地路径。'/opt/ml/checkpoints'如果您更喜欢指定自己的本地路径,请确保与训练脚本中的检查点保存路径和 SageMaker估算器的checkpoint_local_path参数相匹配。

  • checkpoint_s3_uri— URI 到实时存储检查点的 S3 存储桶。您可以指定 S3 通用存储桶或 S3 目录存储桶来存储您的检查点。有关 S3 目录存储桶的更多信息,请参阅 A mazon 简单存储服务用户指南中的目录存储

要查找 SageMaker 估算器参数的完整列表,请参阅 API Amazon Python 文档中的估算器。 SageMaker SDK