本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
启用检查点功能
启用检查点功能后, SageMaker 将检查点保存到 Amazon S3,并将您的训练作业与检查点 S3 存储桶同步。您可以将 S3 通用存储桶或 S3 目录存储桶用于检查点 S3 存储桶。
以下示例说明如何在构造 SageMaker 估算器时配置检查点路径。要启用检查点,将 checkpoint_s3_uri
和 checkpoint_local_path
参数添加到估算器。
以下示例模板展示了如何创建通用 SageMaker 估算器并启用检查点功能。通过指定 image_uri
参数,可以将此模板用于支持的算法。要查找支持检查点URIs的算法的 Docker 镜像 SageMaker,请参阅 Docker 注册表路径和示例代码。您也可以Estimator
用其他 SageMaker 框架的估计器父类和估计器类替换estimator
和,例如、、和。TensorFlow
PyTorch
MXNet
HuggingFace
XGBoost
import sagemaker from sagemaker.
estimator
importEstimator
bucket=sagemaker.Session().default_bucket() base_job_name="sagemaker-checkpoint-test
" checkpoint_in_bucket="checkpoints
" # The S3 URI to store the checkpoints checkpoint_s3_bucket="s3://{}/{}/{}".format(bucket, base_job_name, checkpoint_in_bucket) # The local path where the model will save its checkpoints in the training container checkpoint_local_path="/opt/ml/checkpoints" estimator =Estimator
( ... image_uri="<ecr_path>
/<algorithm-name>
:<tag>
" # Specify to use built-in algorithms output_path=bucket, base_job_name=base_job_name, # Parameters required to enable checkpointing checkpoint_s3_uri=checkpoint_s3_bucket, checkpoint_local_path=checkpoint_local_path )
以下两个参数指定检查点的路径:
-
checkpoint_local_path
– 指定模型定期在训练容器中保存检查点的本地路径。默认路径设置为'/opt/ml/checkpoints'
。如果您使用的是其他框架或自带训练容器,请确保训练脚本的检查点配置指定'/opt/ml/checkpoints'
路径。注意
我们建议指定与默认 SageMaker 检查点设置一致的本地路径。
'/opt/ml/checkpoints'
如果您更喜欢指定自己的本地路径,请确保与训练脚本中的检查点保存路径和 SageMaker估算器的checkpoint_local_path
参数相匹配。 -
checkpoint_s3_uri
— URI 到实时存储检查点的 S3 存储桶。您可以指定 S3 通用存储桶或 S3 目录存储桶来存储您的检查点。有关 S3 目录存储桶的更多信息,请参阅 A mazon 简单存储服务用户指南中的目录存储桶。
要查找 SageMaker 估算器参数的完整列表,请参阅 API Amazon Python 文档中的估算器