를 사용한 체크포인트 지정 SMP - Amazon SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

를 사용한 체크포인트 지정 SMP

SageMaker 모델 병렬 처리(SMP) 라이브러리는 체크포인트를 지원 PyTorch APIs하며 SMP 라이브러리를 사용하는 동안 체크포인트APIs를 올바르게 지원하는 를 제공합니다.

PyTorch FSDP (완전하게 샤딩된 데이터 병렬 처리)는 각각 서로 다른 목적을 제공하는 전체, 샤딩된 및 로컬의 세 가지 유형의 체크포인트를 지원합니다. 전체 체크포인트 생성은 계산 비용이 많이 드는 프로세스이므로 전체 체크포인트는 훈련이 완료된 후 모델을 내보낼 때 사용됩니다. 샤딩된 체크포인트는 각 개별 순위에 대해 샤딩된 모델의 상태를 저장하고 로드하는 데 도움이 됩니다. 샤딩된 체크포인트를 사용하면 다른 수의 와 같은 다양한 하드웨어 구성으로 훈련을 재개할 수 있습니다GPUs. 그러나 여러 디바이스 간의 통신으로 인해 샤딩된 체크포인트 로드가 느려질 수 있습니다. SMP 라이브러리는 추가 통신 오버헤드 없이 모델 상태를 더 빠르게 검색할 수 있는 로컬 체크포인트 기능을 제공합니다. 에서 생성한 체크포인트는 Amazon 과 같은 공유 네트워크 파일 시스템에 써FSDP야 합니다FSx.

로컬 체크포인트 비동기화

기계 학습 모델을 훈련할 때 체크포인트 파일이 디스크에 저장될 때까지 기다리지 않아도 됩니다. SMP v2.5 릴리스와 함께 라이브러리는 체크포인트 파일 비동기 저장을 지원합니다. 즉, 후속 훈련 반복은 I/O 작업으로 인해 속도가 느려지거나 지연되지 않고 체크포인트를 생성하기 위한 입력 및 출력(I/O) 작업과 동시에 실행될 수 있습니다. 또한 에서 샤딩된 모델 및 옵티마이저 파라미터 검색 프로세스는 여러 순위에서 분산 텐서 메타데이터를 교환하는 데 필요한 추가 집합 통신으로 인해 시간이 많이 걸릴 PyTorch 수 있습니다. StateDictType.LOCAL_STATE_DICT 를 사용하여 각 순위에 대한 로컬 체크포인트를 저장할 때도 PyTorch 는 집합 통신을 수행하는 후크를 호출합니다. 이 문제를 완화하고 체크포인트 검색에 필요한 시간을 줄이기 위해 SMStateDictType.SM_LOCAL_STATE_DICT는 집합 통신 오버헤드를 우회하여 모델 및 옵티마이저 체크포인트를 더 빠르게 검색할 수 있는 를 SMP 도입했습니다.

참고

에서 일관성을 유지하는 FSDP SHARD_DEGREE 것은 를 활용하기 위한 요구 사항입니다SMStateDictType.SM_LOCAL_STATE_DICT. 가 변경되지 않은 상태로 SHARD_DEGREE 유지되는지 확인합니다. 모델 복제 수는 다를 수 있지만 체크포인트에서 재개할 때는 모델 샤드 정도가 이전 훈련 설정과 동일해야 합니다.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

다음 코드 조각은 를 사용하여 체크포인트를 로드하는 방법을 보여줍니다SMStateDictType.SM_LOCAL_STATE_DICT.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

대규모 언어 모델(LLMs)에 대한 체크포인트를 저장하려면 대용량 파일 시스템 볼륨을 생성해야 하는 경우가 많기 때문에 비용이 많이 들 수 있습니다. 비용을 절감하기 위해 Amazon 와 같은 추가 파일 시스템 서비스 없이도 체크포인트를 Amazon S3에 직접 저장할 수 있습니다FSx. 다음 코드 조각과 함께 이전 예제를 활용하여 S3를 대상으로 지정하여 S3에 체크포인트를 저장할 수 URL 있습니다.

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

비동기 샤딩된 체크포인트

의 수를 변경하는 등 다양한 하드웨어 구성으로 계속 훈련해야 하는 상황이 있을 수 있습니다GPUs. 이러한 경우 훈련 프로세스는 리샤딩 중에 체크포인트를 로드해야 합니다. 즉, 다른 수의 로 후속 훈련을 재개해야 합니다SHARD_DEGREE. 다른 수의 로 훈련을 재개해야 하는 시나리오를 해결하려면 로 표시되는 샤딩된 상태 사전 유형을 사용하여 모델 체크포인트를 저장해야 SHARD_DEGREE합니다StateDictType.SHARDED_STATE_DICT. 이 형식으로 체크포인트를 저장하면 수정된 하드웨어 구성으로 훈련을 계속할 때 재분배 프로세스를 올바르게 처리할 수 있습니다. 제공된 코드 조각은 를 사용하여 샤딩된 체크포인트를 비동기적으로 저장tsmAPI하여 보다 효율적이고 간소화된 훈련 프로세스를 가능하게 하는 방법을 보여줍니다.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

공유 체크포인트를 로드하는 프로세스는 이전 섹션과 유사하지만 torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader 및 해당 load 방법을 사용해야 합니다. 이 클래스의 load 방법을 사용하면 앞서 설명한 것과 유사한 프로세스에 따라 공유 체크포인트 데이터를 로드할 수 있습니다.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

전체 모델 체크포인트

훈련이 끝나면 모델의 모든 샤드를 단일 모델 체크포인트 파일로 결합하는 전체 체크포인트를 저장할 수 있습니다. SMP 라이브러리는 PyTorch 전체 모델 체크포인트 를 완전히 지원API하므로 변경할 필요가 없습니다.

를 사용하면 SMP 텐서 병렬화 SMP 라이브러리가 모델을 변환합니다. 이 경우 전체 모델을 체크포인트할 때 SMP 라이브러리는 모델을 Hugging Face Transformers 체크포인트 형식으로 다시 변환합니다.

SMP 텐서 병렬 처리로 훈련하고 SMP 번역 프로세스를 끄는 경우 필요에 따라 의 translate_on_save 인수 PyTorch FullStateDictConfigAPI를 사용하여 SMP 자동 번역을 켜거나 끌 수 있습니다. 예를 들어 모델 훈련에 집중하는 경우 오버헤드를 추가하는 번역 프로세스를 추가할 필요가 없습니다. 이 경우 를 설정하는 것이 좋습니다translate_on_save=False. 또한 향후 추가 훈련을 위해 모델 SMP 번역을 계속 사용할 계획이라면 끌 수 있어 나중에 사용할 수 있도록 모델 SMP 번역을 저장할 수 있습니다. 모델 훈련을 마무리하고 추론에 사용할 때는 모델을 Hugging Face Transformers 모델 체크포인트 형식으로 다시 변환해야 합니다.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

옵션은 대형 모델을 훈련할 때 메모리를 절약하기 위해 0순위 디바이스CPU의 에서 모델을 수집하는 FullStateDictConfig(rank0_only=True, offload_to_cpu=True) 것입니다.

추론을 위해 모델을 다시 로드하려면 다음 코드 예제와 같이 로드합니다. 모델에 AutoModelForSeq2SeqLM따라 클래스가 Hugging Face Transformer의 다른 팩터 빌더 클래스로 변경될 AutoModelForCausalLM 수 있습니다. 자세한 내용은 Hugging Face Transformers 설명서를 참조하세요.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)