Inisialisasi parameter tertunda - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Inisialisasi parameter tertunda

Inisialisasi model besar untuk pelatihan tidak selalu memungkinkan dengan GPU memori terbatas. Untuk mengatasi masalah GPU memori yang tidak mencukupi ini, Anda dapat menginisialisasi model pada CPU memori. Namun, untuk model yang lebih besar dengan lebih dari 20 atau 40 miliar parameter, bahkan CPU memori mungkin tidak cukup. Untuk kasus seperti itu, kami menyarankan Anda menginisialisasi model pada apa yang PyTorch disebut perangkat meta, yang memungkinkan pembuatan tensor tanpa data apa pun yang melekat padanya. Tensor pada perangkat meta hanya membutuhkan informasi bentuk, dan ini memungkinkan untuk membuat model besar dengan parameternya pada perangkat meta. Hugging Face Accelerate menyediakan init_empty_weights manajer konteks untuk membantu membuat model seperti itu pada perangkat meta sambil menginisialisasi buffer pada perangkat biasa. Sebelum pelatihan dimulai, PyTorch FSDP inisialisasi parameter model. Fitur inisialisasi parameter tertunda SMP v2 ini menunda pembuatan parameter model ini terjadi setelah PyTorch FSDP melakukan sharding parameter. PyTorch FSDPmenerima fungsi inisialisasi parameter (param_init_fn) saat sharding modul, dan memanggil param_init_fn setiap modul. Ini param_init_fn API mengambil modul sebagai argumen dan menginisialisasi semua parameter di dalamnya, tidak termasuk parameter modul anak apa pun. Perhatikan bahwa perilaku ini berbeda dari PyTorch v2.0.1 asli yang memiliki bug yang menyebabkan parameter diinisialisasi beberapa kali.

SMPv2 menyediakan torch.sagemaker.delayed_param.DelayedParamIniter API untuk menerapkan inisialisasi parameter tertunda.

Cuplikan kode berikut menunjukkan cara menerapkan skrip torch.sagemaker.delayed_param.DelayedParamIniter API pelatihan Anda.

Asumsikan bahwa Anda memiliki skrip PyTorch FSDP pelatihan sebagai berikut.

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

Perhatikan bahwa pendekatan inisialisasi parameter tertunda bukanlah model agnostik. Untuk mengatasi masalah ini, Anda perlu menulis init_weights fungsi seperti yang ditunjukkan pada contoh sebelumnya agar sesuai dengan inisialisasi dalam definisi model asli, dan itu harus mencakup semua parameter model. Untuk menyederhanakan proses persiapan init_weights fungsi tersebut, SMP v2 mengimplementasikan fungsi inisialisasi ini untuk model berikut: GPT -2, GPT -J, GPT -NeoX, dan Llama dari Hugging Face Transformers. Ini torch.sagemaker.delayed_param.DelayedParamIniter API juga berfungsi dengan implementasi paralel SMP tensor, torch.sagemaker.tensor_parallel.transformer.TransformerLMHead model, yang dapat Anda panggil setelah torch.sagemaker.transform API panggilan.

Dengan menggunakan torch.sagemaker.delayed_param.DelayedParamIniterAPI, Anda dapat menyesuaikan PyTorch FSDP skrip Anda sebagai berikut. Setelah membuat model dengan bobot kosong, daftarkan torch.sagemaker.delayed_param.DelayedParamIniter API ke model, dan tentukan objeknya. Lewati objek ke param_init_fn PyTorch FSDP kelas.

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

Catatan tentang bobot yang diikat

Saat melatih model dengan bobot terikat, kita perlu berhati-hati untuk mengikat bobot setelah menginisialisasi bobot dengan inisialisasi parameter yang tertunda. PyTorchFSDPtidak memiliki mekanisme untuk mengikat bobot setelah menginisialisasi mereka menggunakan param_init_fn seperti di atas. Untuk mengatasi kasus seperti itu, kami menambahkan API untuk mengizinkan apost_init_hook_fn, yang dapat digunakan untuk mengikat bobot. Anda dapat meneruskan fungsi apa pun di sana yang menerima modul sebagai argumen, tetapi kami juga memiliki standar yang post_param_init_fn ditentukan di DelayedParamIniter mana tie_weights metode panggilan modul jika ada. Perhatikan bahwa aman untuk selalu masuk post_param_init_fn meskipun tidak ada tie_weights metode untuk modul.

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )