텐서 병렬화 - Amazon SageMaker

텐서 병렬화

텐서 병렬 처리는 특정 모델 가중치, 그라디언트, 옵티마이저 상태가 디바이스 간에 분할되는 일종의 모델 병렬 처리입니다. 개별 가중치는 그대로 유지하되 가중치 세트, 그라데이션 또는 여러 기기에 걸친 옵티마이저는 분할하는 파이프라인 병렬 처리와 달리 텐서 병렬 처리는 개별 가중치를 분할합니다. 여기에는 일반적으로 해당 모델의 특정 연산, 모듈 또는 계층에 대한 분산 계산이 포함됩니다.

단일 파라미터가 대부분의 GPU 메모리를 소비하는 경우(예: 어휘 크기가 큰 대형 임베딩 테이블 또는 클래스 개수가 많은 대규모 소프트맥스 레이어) 텐서 병렬 처리가 필요합니다. 이 경우 이렇게 큰 텐서 또는 연산을 원자 단위로 처리하는 것은 비효율적이며 메모리 부하의 균형을 맞추는 데 방해가 됩니다.

SMP v2는 텐서 병렬 처리를 위한 구현을 위해 Transformer Engine과 통합되며 PyTorch FSDP API 위에서 실행됩니다. PyTorch FSDP 및 SMP 텐서 병렬 처리를 동시에 활성화하고 최상의 성능을 위한 최상의 모델 병렬 처리를 결정할 수 있습니다.

실제로 텐서 병렬 처리는 다음 시나리오에서 특히 유용합니다.

  • 긴 컨텍스트 길이로 훈련하면 FSDP만으로 활성화 메모리가 높아집니다.

  • 글로벌 배치 크기가 원하는 한도를 초과하는 매우 큰 클러스터로 훈련하는 경우.

SMP 텐서 병렬 처리와 호환되는 Hugging Face 트랜스포머 모델

SMP v2는 현재 다음 Hugging Face 트랜스포머 모델에 대한 텐서 병렬 처리를 지원합니다.

이러한 모델에 텐서 병렬 처리를 적용하기 위한 참조 구성은 구성 팁 섹션을 참조하세요.

텐서 병렬 처리 구성

tensor_parallel_degree의 경우 텐서 병렬 처리 정도에 대한 값을 선택합니다. 값은 클러스터의 GPU 균등하게 나누어야 합니다. 예를 들어 GPU가 8개인 인스턴스를 사용하는 동안 모델을 샤딩하려면 2, 4 또는 8개 GPU 선택합니다. 적은 숫자로 시작하고 모델이 GPU 메모리에 적합할 때까지 점진적으로 늘리는 것이 좋습니다.

다음 코드 조각은 SageMaker 모델 병렬 처리 라이브러리 v2 사용에 도입된 2단계 프로세스를 따르면서 훈련 스크립트에 SMP 초기화 모듈 torch.sagemaker.init()을 추가하고 훈련 작업 시작 관리자를 위한 JSON 형식으로 SMP 구성 사전을 설정하는 방법을 보여줍니다. PyTorch 모델 또는 PyTorch FSDP 구성을 변경할 필요가 없습니다. tensor_parallel_degreerandom_seed 파라미터에 대한 자세한 내용은 SMP v2 코어 기능 구성 파라미터 단원을 참조하세요.

SMP 구성

{ "tensor_parallel_degree": 8, "random_seed": 0 }

훈련 스크립트에서

torch.sagemaker.init()로 초기화하여 SMP v2를 활성화하고 torch.sagemaker.transform API로 모델을 래핑합니다.

import torch.sagemaker as tsm tsm.init() from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(..) model = tsm.transform(model)

Hugging Face 트랜스포머 체크포인트 저장 및 로드

SMP 라이브러리가 모델을 변환하면 모델의 상태 사전(state_dict)이 변경됩니다. 즉, 모델이 원래 Hugging Face 트랜스포머 체크포인트 기능과 호환되지 않습니다. 이를 처리하기 위해 SMP 라이브러리는 Hugging Face 트랜스포머 표현의 변환된 모델에서 체크포인트를 저장하는 API와 미세 조정을 위한 Hugging Face 트랜스포머 모델 체크포인트를 로드하는 torch.sagemaker.transform API를 제공합니다.

SMP v2의 텐서 병렬 처리 기능을 사용하는 동안 체크포인트를 저장하는 방법에 대한 자세한 내용은 SMP를 사용한 체크포인트 지정 섹션을 참조하세요.

SMP v2의 텐서 병렬 처리 기능을 적용하는 모델을 미세 조정하는 방법에 대한 자세한 내용은 미세 조정 섹션을 참조하세요.