체크포인팅: 사용 SMP - 아마존 SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

체크포인팅: 사용 SMP

SageMaker model parallel lism (SMP) 라이브러리는 체크포인트를 지원하며 라이브러리를 사용하는 PyTorch APIs 동안 체크포인트를 올바르게 사용할 수 APIs 있도록 도와줍니다. SMP

PyTorch FSDP(완전 샤딩된 데이터 병렬화) 은 세 가지 유형의 체크포인트, 즉 전체 체크포인트, 샤딩된 체크포인트, 로컬 체크포인트를 지원하며 각각 용도가 다릅니다. 전체 체크포인트를 생성하려면 계산 비용이 많이 들기 때문에 학습이 완료된 후 모델을 내보낼 때는 전체 체크포인트가 사용됩니다. 분할된 체크포인트를 사용하면 각 개별 순위에 맞게 분할된 모델의 상태를 저장하고 로드할 수 있습니다. 샤딩된 체크포인트를 사용하면 다양한 하드웨어 구성 (예: 개수 변경) 으로 학습을 재개할 수 있습니다. GPUs 하지만 여러 장치 간에 통신이 필요하기 때문에 샤딩된 체크포인트를 로드하는 속도가 느릴 수 있습니다. SMP라이브러리는 로컬 체크포인트 기능을 제공하므로 추가 통신 오버헤드 없이 모델 상태를 더 빠르게 검색할 수 있습니다. 참고로 체크포인트를 생성하려면 FSx Amazon과 같은 공유 네트워크 파일 시스템에 FSDP 작성해야 합니다.

비동기 로컬 체크포인트

머신 러닝 모델을 학습시킬 때는 체크포인트 파일이 디스크에 저장될 때까지 후속 반복 작업을 기다릴 필요가 없습니다. SMPv2.5가 출시되면서 라이브러리는 체크포인트 파일을 비동기적으로 저장하는 기능을 지원합니다. 즉, 후속 교육 반복을 해당 I/O 작업으로 인한 속도 저하나 지연 없이 체크포인트 생성을 위한 입력 및 출력 (I/O) 작업과 동시에 실행할 수 있습니다. 또한 분할된 모델 및 옵티마이저 파라미터를 가져오는 프로세스는 랭크 간에 분산된 텐서 메타데이터를 교환하는 데 추가적인 공동 통신이 필요하기 때문에 시간이 많이 걸릴 PyTorch 수 있습니다. 를 StateDictType.LOCAL_STATE_DICT 사용하여 각 랭크의 로컬 체크포인트를 저장하는 경우에도 집단 통신을 수행하는 후크를 호출합니다. PyTorch 이 문제를 완화하고 체크포인트 검색에 필요한 시간을 줄이기 위해 집단적 통신 오버헤드를 우회하여 모델 및 옵티마이저 체크포인트를 더 빠르게 검색할 수 있는 기능을 SMP SMStateDictType.SM_LOCAL_STATE_DICT 도입했습니다.

참고

를 활용하려면 일관성을 유지하는 것이 필수적입니다. FSDP SHARD_DEGREE SMStateDictType.SM_LOCAL_STATE_DICT SHARD_DEGREE변경되지 않았는지 확인하십시오. 모델 복제 횟수는 다를 수 있지만 체크포인트에서 다시 시작할 때는 모델 샤드 정도가 이전 학습 설정과 동일해야 합니다.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

다음 코드 스니펫은 이를 활용하여 체크포인트를 로드하는 방법을 보여줍니다. SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

대형 언어 모델 (LLMs) 에 체크포인트를 저장하려면 대량의 파일 시스템 볼륨을 만들어야 하는 경우가 많기 때문에 비용이 많이 들 수 있습니다. 비용 절감을 위해 Amazon과 같은 추가 파일 시스템 서비스 없이 Amazon S3에 체크포인트를 직접 저장할 수 있는 옵션이 있습니다. FSx 다음 코드 스니펫과 함께 이전 예제를 활용하여 S3를 대상으로 지정하여 S3에 체크포인트를 저장할 수 있습니다. URL

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

비동기 샤딩 체크포인트

개수 변경과 같이 다양한 하드웨어 구성을 사용하여 계속 학습해야 하는 상황이 있을 수 있습니다. GPUs 이러한 경우 리샤딩하는 동안 학습 프로세스가 체크포인트를 로드해야 합니다. 즉, 다른 개수로 후속 교육을 재개해야 합니다. SHARD_DEGREE 다른 개수로 학습을 재개해야 하는 시나리오를 해결하려면 샤딩된 상태 사전 유형 (으로 표시됨) 을 사용하여 모델 체크포인트를 저장해야 합니다. SHARD_DEGREE StateDictType.SHARDED_STATE_DICT 이 형식으로 체크포인트를 저장하면 수정된 하드웨어 구성으로 학습을 계속할 때 리샤딩 프로세스를 적절하게 처리할 수 있습니다. 제공된 코드 스니펫은 를 사용하여 샤딩된 체크포인트를 비동기적으로 저장하는 방법을 보여 주므로 교육 프로세스를 더욱 효율적이고 간소화할 수 있습니다. tsm API

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

공유 체크포인트를 로드하는 프로세스는 이전 섹션과 비슷하지만 와 해당 방법을 사용해야 합니다. torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader load 이 클래스의 load 메서드를 사용하면 앞에서 설명한 것과 유사한 프로세스에 따라 공유 체크포인트 데이터를 로드할 수 있습니다.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

전체 모델 체크포인트

학습이 끝나면 모델의 모든 샤드를 단일 모델 체크포인트 파일로 결합하는 전체 체크포인트를 저장할 수 있습니다. SMP라이브러리는 PyTorch 전체 모델 체크포인트를 완벽하게 API 지원하므로 변경할 필요가 없습니다.

참고로 를 SMP 텐서 병렬화 사용하면 SMP 라이브러리가 모델을 변환합니다. 이 경우 전체 모델을 체크포인팅할 때 SMP 라이브러리는 기본적으로 모델을 Hugging Face Transformer 체크포인트 형식으로 다시 변환합니다.

SMP텐서 병렬 처리를 사용하여 학습하고 SMP 변환 프로세스를 끄는 경우 필요에 따라 의 translate_on_save 인수를 사용하여 SMP 자동 변환을 켜거나 끌 수 있습니다. PyTorch FullStateDictConfig API 예를 들어 모델 학습에 집중하고 있다면 오버헤드를 가중시키는 번역 프로세스를 추가할 필요가 없습니다. 이 경우 설정하는 translate_on_save=False 것이 좋습니다. 또한 향후 추가 교육을 위해 모델 SMP 번역을 계속 사용할 계획이라면 나중에 사용할 수 있도록 모델 번역을 끄고 모델 SMP 번역을 저장할 수 있습니다. 모델 학습을 마무리하고 이를 추론에 사용할 때는 모델을 Hugging Face Transformer 모델 체크포인트 형식으로 다시 변환해야 합니다.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

참고로, 대형 모델을 FullStateDictConfig(rank0_only=True, offload_to_cpu=True) 학습시킬 때는 0순위 장치에 모델을 수집하여 메모리를 절약하는 옵션이 있습니다. CPU

추론을 위해 모델을 다시 불러오려면 다음 코드 예제와 같이 하십시오. 모델에 따라 Hugging Face Transformer의 다른 팩터 빌더 클래스로 클래스가 AutoModelForCausalLM 변경될 수 있다는 점에 유의하세요. AutoModelForSeq2SeqLM 자세한 내용은 Hugging Face Transformer 설명서를 참조하십시오.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)