使用 的檢查點 SMP - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用 的檢查點 SMP

SageMaker 模型平行處理 (SMP) 程式庫支援 PyTorch APIs檢查點,並提供APIs使用程式SMP庫時協助正確檢查點的功能。

PyTorch FSDP (完整碎片資料平行處理) 支援三種檢查點類型:完整、碎片和本機,每個檢查點都提供不同的用途。訓練完成後匯出模型時會使用完整檢查點,因為產生完整檢查點是運算成本高昂的程序。碎片檢查點有助於儲存和載入針對每個個別排名碎片的模型狀態。使用碎片檢查點,您可以使用不同的硬體組態繼續訓練,例如不同數量的 GPUs。不過,載入碎片檢查點可能會很慢,因為多個裝置之間涉及通訊。SMP 程式庫提供本機檢查點功能,可更快速擷取模型的狀態,無需額外的通訊開銷。請注意, 建立的檢查點FSDP需要寫入共用網路檔案系統,例如 Amazon FSx。

非同步本機檢查點

訓練機器學習模型時,不需要等待檢查點檔案儲存至磁碟的後續反覆運算。隨著 SMP v2.5 的發行,程式庫支援以非同步方式儲存檢查點檔案。這表示後續的訓練反覆運算可以與輸入和輸出 (I/O) 操作同時執行,以建立檢查點,而不會受到這些 I/O 操作的減慢或延遲。此外,擷取 中的碎片模型和最佳化工具參數的程序 PyTorch 可能很耗時,因為在各排名之間交換分散式張量中繼資料所需的額外集體通訊。即使使用 StateDictType.LOCAL_STATE_DICT為每個排名儲存本機檢查點, PyTorch 仍會叫用執行集體通訊的勾點。為了緩解此問題並縮短檢查點擷取所需的時間, SMP引進 SMStateDictType.SM_LOCAL_STATE_DICT,透過略過集體通訊開銷,可更快速擷取模型和最佳化工具檢查點。

注意

在 中保持一致性FSDPSHARD_DEGREE是使用 的必要條件SMStateDictType.SM_LOCAL_STATE_DICT。確保 SHARD_DEGREE 保持不變。雖然模型複寫的數量可能不同,但從檢查點繼續時,模型碎片程度必須與先前的訓練設定相同。

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

下列程式碼片段示範如何使用 載入檢查點SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

儲存大型語言模型的檢查點 (LLMs) 可能很昂貴,因為它通常需要建立大型檔案系統磁碟區。若要降低成本,您可以選擇將檢查點直接儲存到 Amazon S3,而不需要額外的檔案系統服務,例如 Amazon FSx。您可以使用下列程式碼片段利用先前的範例,將 S3 指定URL為目的地,將檢查點儲存至 S3。

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

非同步碎片檢查點

在某些情況下,您可能需要繼續使用不同的硬體組態進行訓練,例如變更 的數量GPUs。在這些情況下,您的訓練程序必須在重新分片時載入檢查點,這表示使用不同數量的 繼續進行後續訓練SHARD_DEGREE。為了解決您需要使用不同數量的 繼續進行訓練的情況SHARD_DEGREE,您必須使用碎片狀態字典類型來儲存模型檢查點,該類型由 表示StateDictType.SHARDED_STATE_DICT。使用此格式儲存檢查點可讓您在繼續使用修改後的硬體組態進行訓練時,正確處理重新分片程序。提供的程式碼片段說明如何使用 以tsmAPI非同步方式儲存碎片檢查點,從而實現更有效率且簡化的訓練程序。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

載入共用檢查點的程序與上一節類似,但涉及使用 torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader及其load方法。此類別的 load方法可讓您載入共用檢查點資料,遵循類似先前所述程序的程序。

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

完整模型檢查點

在訓練結束時,您可以儲存完整的檢查點,將模型的所有碎片合併為單一模型檢查點檔案。SMP 程式庫完全支援 PyTorch 完整的模型檢查點 API,因此您不需要進行任何變更。

請注意,如果您使用 SMP 張量平行處理,程式SMP庫會轉換模型。在這種情況下,檢查點完整模型時,SMP程式庫預設會將模型轉譯回 Hugging Face Transformer 檢查點格式。

如果您使用SMP張量平行處理進行訓練並關閉SMP轉譯程序,您可以使用 的translate_on_save引數視需要 PyTorch FullStateDictConfigAPI開啟或關閉SMP自動轉譯。例如,如果您專注於訓練模型,則不需要新增會增加額外負荷的翻譯程序。在這種情況下,我們建議您設定 translate_on_save=False。此外,如果您打算繼續使用模型的SMP翻譯來進一步訓練,您可以將其關閉,以儲存模型的SMP翻譯以供日後使用。當您結束模型的訓練並使用模型進行推論時,需要將模型轉返 Hugging Face 轉換器模型檢查點格式。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

請注意, 選項FullStateDictConfig(rank0_only=True, offload_to_cpu=True)是在第 CPU 0 排名裝置的 上收集模型,以便在訓練大型模型時節省記憶體。

若要重新載入模型以進行推論,請執行此操作,如下列程式碼範例所示。請注意, 類別AutoModelForCausalLM可能會變更為 Hugging Face Transformer 中的其他因素建置器類別,例如 AutoModelForSeq2SeqLM,視您的模型而定。如需詳細資訊,請參閱 Hugging Face Transformer 文件

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)