Checkpointing menggunakan SMP - Amazon SageMaker AI

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Checkpointing menggunakan SMP

Pustaka SageMaker model paralelisme (SMP) mendukung pos pemeriksaan, dan PyTorch APIs menyediakan pos pemeriksaan bantuan dengan benar saat APIs menggunakan pustaka SMP.

PyTorch FSDP (Fully Sharded Data Parallelism) mendukung tiga jenis pos pemeriksaan: penuh, sharded, dan lokal, masing-masing melayani tujuan yang berbeda. Pos pemeriksaan penuh digunakan saat mengekspor model setelah pelatihan selesai, karena menghasilkan pos pemeriksaan penuh adalah proses yang mahal secara komputasi. Pos pemeriksaan sharded membantu menyimpan dan memuat status model yang dipecah untuk setiap peringkat individu. Dengan pos pemeriksaan sharded, Anda dapat melanjutkan pelatihan dengan konfigurasi perangkat keras yang berbeda, seperti jumlah yang berbeda. GPUs Namun, memuat pos pemeriksaan sharded bisa lambat karena komunikasi yang terlibat di antara beberapa perangkat. Pustaka SMP menyediakan fungsionalitas pos pemeriksaan lokal, yang memungkinkan pengambilan status model lebih cepat tanpa overhead komunikasi tambahan. Perhatikan bahwa pos pemeriksaan yang dibuat oleh FSDP memerlukan penulisan ke sistem file jaringan bersama seperti Amazon. FSx

Pos pemeriksaan lokal async

Saat melatih model pembelajaran mesin, tidak perlu iterasi berikutnya untuk menunggu file pos pemeriksaan disimpan ke disk. Dengan dirilisnya SMP v2.5, perpustakaan mendukung penyimpanan file pos pemeriksaan secara asinkron. Ini berarti bahwa iterasi pelatihan berikutnya dapat berjalan bersamaan dengan operasi input dan output (I/O) untuk membuat pos pemeriksaan, tanpa diperlambat atau ditahan oleh operasi I/O tersebut. Selain itu, proses pengambilan model sharded dan paramemeter pengoptimal PyTorch dapat memakan waktu karena komunikasi kolektif tambahan yang diperlukan untuk menukar metadata tensor terdistribusi di seluruh peringkat. Bahkan ketika menggunakan StateDictType.LOCAL_STATE_DICT untuk menyimpan pos pemeriksaan lokal untuk setiap peringkat, PyTorch masih memanggil kait yang melakukan komunikasi kolektif. Untuk mengurangi masalah ini dan mengurangi waktu yang diperlukan untuk pengambilan pos pemeriksaan, SMP memperkenalkanSMStateDictType.SM_LOCAL_STATE_DICT, yang memungkinkan pengambilan lebih cepat dari pos pemeriksaan model dan pengoptimal dengan melewati overhead komunikasi kolektif.

catatan

Menjaga konsistensi dalam FSDP SHARD_DEGREE adalah persyaratan untuk memanfaatkan. SMStateDictType.SM_LOCAL_STATE_DICT Pastikan bahwa SHARD_DEGREE sisa-sisa tidak berubah. Sementara jumlah replikasi model dapat bervariasi, tingkat pecahan model harus identik dengan pengaturan pelatihan sebelumnya saat melanjutkan dari pos pemeriksaan.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

Cuplikan kode berikut menunjukkan cara memuat pos pemeriksaan menggunakan. SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

Menyimpan pos pemeriksaan untuk model bahasa besar (LLMs) bisa mahal karena sering membutuhkan pembuatan volume sistem file yang besar. Untuk mengurangi biaya, Anda memiliki opsi untuk menyimpan pos pemeriksaan langsung ke Amazon S3 tanpa perlu layanan sistem file tambahan seperti Amazon. FSx Anda dapat memanfaatkan contoh sebelumnya dengan cuplikan kode berikut untuk menyimpan pos pemeriksaan ke S3 dengan menentukan URL S3 sebagai tujuan.

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Pos pemeriksaan sharded async

Mungkin ada situasi di mana Anda perlu melanjutkan pelatihan dengan konfigurasi perangkat keras yang berbeda, seperti mengubah jumlah. GPUs Dalam kasus ini, proses pelatihan Anda harus memuat pos pemeriksaan saat resharding, yang berarti melanjutkan pelatihan berikutnya dengan jumlah yang berbeda. SHARD_DEGREE Untuk mengatasi skenario di mana Anda perlu melanjutkan pelatihan dengan jumlah yang berbedaSHARD_DEGREE, Anda harus menyimpan pos pemeriksaan model Anda menggunakan jenis kamus status sharded, yang diwakili oleh. StateDictType.SHARDED_STATE_DICT Menyimpan pos pemeriksaan dalam format ini memungkinkan Anda menangani proses resharding dengan benar saat melanjutkan pelatihan dengan konfigurasi perangkat keras yang dimodifikasi. Cuplikan kode yang disediakan menggambarkan cara menggunakan tsm API untuk menyimpan pos pemeriksaan sharded secara asinkron, memungkinkan proses pelatihan yang lebih efisien dan efisien.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Proses memuat pos pemeriksaan bersama mirip dengan bagian sebelumnya, tetapi melibatkan penggunaan torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader dan load metodenya. loadMetode kelas ini memungkinkan Anda untuk memuat data pos pemeriksaan bersama, mengikuti proses analog dengan yang dijelaskan sebelumnya.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Pos pemeriksaan model lengkap

Di akhir pelatihan, Anda dapat menyimpan pos pemeriksaan lengkap yang menggabungkan semua pecahan model ke dalam satu file pos pemeriksaan model. Pustaka SMP sepenuhnya mendukung API pos pemeriksaan model PyTorch lengkap, jadi Anda tidak perlu melakukan perubahan apa pun.

Perhatikan bahwa jika Anda menggunakan SMPParalelisme tensor, perpustakaan SMP mengubah model. Saat memeriksa model lengkap dalam kasus ini, pustaka SMP menerjemahkan model kembali ke format pos pemeriksaan Hugging Face Transformers secara default.

Jika Anda berlatih dengan paralelisme tensor SMP dan mematikan proses penerjemahan SMP, Anda dapat menggunakan translate_on_save argumen PyTorch FullStateDictConfig API untuk mengaktifkan atau menonaktifkan terjemahan otomatis SMP sesuai kebutuhan. Misalnya, jika Anda berfokus pada pelatihan model, Anda tidak perlu menambahkan proses terjemahan yang menambahkan overhead. Dalam hal ini, kami sarankan Anda untuk mengaturtranslate_on_save=False. Juga, jika Anda berencana untuk tetap menggunakan terjemahan SMP model untuk pelatihan lebih lanjut di masa depan, Anda dapat mematikannya untuk menyimpan terjemahan SMP model untuk digunakan nanti. Menerjemahkan model kembali ke format pos pemeriksaan model Hugging Face Transformers diperlukan saat Anda menyelesaikan pelatihan model Anda dan menggunakannya untuk inferensi.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Perhatikan bahwa pilihannya FullStateDictConfig(rank0_only=True, offload_to_cpu=True) adalah mengumpulkan model pada CPU perangkat peringkat 0 untuk menghemat memori saat melatih model besar.

Untuk memuat kembali model untuk inferensi, Anda melakukannya seperti yang ditunjukkan pada contoh kode berikut. Perhatikan bahwa kelas AutoModelForCausalLM mungkin berubah ke kelas pembuat faktor lain di Hugging Face Transformers, AutoModelForSeq2SeqLM seperti, tergantung pada model Anda. Untuk informasi selengkapnya, lihat dokumentasi Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)