使用 JumpStartEstimator類別微調公開可用的基礎模型 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用 JumpStartEstimator類別微調公開可用的基礎模型

您可以使用 ,在幾行程式碼中微調內建演算法或預先訓練的模型 SageMaker Python SDK.

  1. 首先,在具有預先訓練模型表 的內建演算法中找到您選擇的模型的模型 ID。

  2. 使用模型 ID,將訓練任務 JumpStart定義為估算器。

    from sagemaker.jumpstart.estimator import JumpStartEstimator model_id = "huggingface-textgeneration1-gpt-j-6b" estimator = JumpStartEstimator(model_id=model_id)
  3. 在模型estimator.fit()上執行,指向要用於微調的訓練資料。

    estimator.fit( {"train": training_dataset_s3_path, "validation": validation_dataset_s3_path} )
  4. 然後,使用 deploy方法自動部署模型以進行推論。在此範例中,我們使用來自 的 GPT-J 6B 模型 Hugging Face.

    predictor = estimator.deploy()
  5. 然後,您可以使用 predict方法,使用部署的模型執行推論。

    question = "What is Southern California often abbreviated as?" response = predictor.predict(question) print(response)
注意

此範例使用基礎模型 GPT-J 6B ,適用於各種文字產生使用案例,包括問題回答、具名實體識別、摘要等。如需模型使用案例的詳細資訊,請參閱 可用的基礎模型

您可以在建立 時選擇性地指定模型版本或執行個體類型JumpStartEstimator。如需JumpStartEstimator 類別及其參數的詳細資訊,請參閱 JumpStartEstimator

檢查預設執行個體類型

在使用 JumpStartEstimator 類別微調預先訓練的模型時,您可以選擇包含特定的模型版本或執行個體類型。所有 JumpStart 模型都有預設執行個體類型。使用下列程式碼擷取預設訓練執行個體類型:

from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope="training") print(instance_type)

您可以使用 instance_types.retrieve()方法查看指定 JumpStart 模型的所有支援執行個體類型。

檢查預設超參數

若要檢查用於訓練的預設超參數,您可以從 hyperparameters類別使用 retrieve_default()方法。

from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)

如需可用超參數的詳細資訊,請參閱 通常支援的微調超參數

檢查預設指標定義

您也可以檢查預設指標定義:

print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))