使用 進行檢查點 SMP - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用 進行檢查點 SMP

SageMaker 模型平行處理 (SMP) 程式庫支援 PyTorch APIs檢查點,並在使用SMP程式庫時APIs提供適當的協助檢查點。

PyTorch FSDP (完整碎片資料平行處理) 支援三種類型的檢查點:完整、碎片和本機,每個檢查點都具有不同的用途。訓練完成後匯出模型時會使用完整檢查點,因為產生完整檢查點是運算昂貴的程序。碎片檢查點有助於儲存和載入針對每個個別層級碎片的模型狀態。使用碎片檢查點,您可以使用不同的硬體組態繼續訓練,例如不同的 數目GPUs。不過,由於多個裝置之間涉及通訊,因此載入碎片檢查點可能會很慢。SMP 程式庫提供本機檢查點功能,可更快速擷取模型的狀態,無需額外的通訊額外負荷。請注意,由 建立的檢查點FSDP需要寫入共用網路檔案系統,例如 Amazon FSx。

非同步本機檢查點

訓練機器學習模型時,不需要等待檢查點檔案儲存至磁碟的後續反覆運算。隨著 SMP v2.5 的發行,程式庫支援非同步儲存檢查點檔案。這表示後續的訓練反覆運算可與輸入和輸出 (I/O) 操作同時執行,以建立檢查點,而不會被這些 I/O 操作減慢或扣留。此外,擷取碎片模型和最佳化工具參數的程序 PyTorch 可能很耗時,因為在各排名之間交換分散式張量中繼資料所需的額外集體通訊。即使使用 StateDictType.LOCAL_STATE_DICT 為每個等級儲存本機檢查點, PyTorch 仍會叫用執行集體通訊的掛鉤。為了減輕此問題並縮短檢查點擷取所需的時間, SMP引入 SMStateDictType.SM_LOCAL_STATE_DICT,這允許透過繞過集體通訊額外負荷來更快擷取模型和最佳化工具檢查點。

注意

在 中保持一致性FSDPSHARD_DEGREE是使用 的必要條件SMStateDictType.SM_LOCAL_STATE_DICT。確保 SHARD_DEGREE保持不變。雖然模型複寫的數量可能不同,但從檢查點恢復時,模型碎片程度必須與先前的訓練設定相同。

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

下列程式碼片段示範如何使用 載入檢查點SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

儲存大型語言模型 (LLMs) 的檢查點可能很昂貴,因為它通常需要建立大型檔案系統磁碟區。若要降低成本,您可以選擇將檢查點直接儲存至 Amazon S3,而不需要額外的檔案系統服務,例如 Amazon FSx。您可以利用上一個範例搭配下列程式碼片段,將 S3 指定URL為目的地,將檢查點儲存至 S3。

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

非同步碎片檢查點

在某些情況下,您可能需要繼續使用不同的硬體組態進行訓練,例如變更 的數量GPUs。在這些情況下,您的訓練程序必須在重新共用時載入檢查點,這表示使用不同數量的 繼續進行後續訓練SHARD_DEGREE。為了解決您需要使用不同數量的 繼續進行訓練的情況SHARD_DEGREE,您必須使用碎片狀態字典類型來儲存模型檢查點,該類型由 表示StateDictType.SHARDED_STATE_DICT。以此格式儲存檢查點可讓您在繼續使用修改後的硬體組態訓練時,正確處理重新分配程序。提供的程式碼片段說明如何使用 以tsmAPI非同步方式儲存碎片檢查點,從而實現更有效率且更簡化的訓練程序。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

載入共用檢查點的程序與上一節類似,但涉及使用 torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader 及其load方法。此類別load的方法可讓您根據類似上述的程序載入共用檢查點資料。

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

完整模型檢查點

在訓練結束時,您可以儲存完整的檢查點,將模型的所有碎片合併為單一模型檢查點檔案。SMP 程式庫完全支援 PyTorch 完整的模型檢查點 API,因此您不需要進行任何變更。

請注意,如果您使用 SMP 張量平行處理,則SMP程式庫會轉換模型。在這種情況下檢查完整模型時,SMP程式庫預設會將模型轉譯回 Hugging Face Transformer 檢查點格式。

如果您使用SMP張量平行處理訓練並關閉SMP翻譯程序,則可以使用 的translate_on_save引數視需要 PyTorch FullStateDictConfigAPI開啟或關閉SMP自動翻譯。例如,如果您專注於訓練模型,則不需要新增會增加額外負荷的翻譯程序。在這種情況下,我們建議您設定 translate_on_save=False。此外,如果您打算繼續使用模型的SMP翻譯來進一步訓練,您可以將其關閉,以儲存模型的SMP翻譯以供日後使用。當您結束模型訓練並使用模型進行推論時,需要將模型轉換回 Hugging Face Transformer 模型檢查點格式。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

請注意, 選項FullStateDictConfig(rank0_only=True, offload_to_cpu=True)是在 第 0 階裝置的 CPU 上收集模型,以便在訓練大型模型時節省記憶體。

若要重新載入模型以進行推論,您可以執行此操作,如下列程式碼範例所示。請注意, 類別AutoModelForCausalLM可能會變更為 Hugging Face Transformer 中的其他因素建置器類別,例如 AutoModelForSeq2SeqLM,視您的模型而定。如需詳細資訊,請參閱 Hugging Face Transformer 文件

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)