公開されている基盤モデルを JumpStartEstimator
クラスでファインチューニングする
SageMaker Python SDK を使用して、組み込みアルゴリズムや事前トレーニング済みのモデルをわずか数行のコードでファインチューニングできます。
-
まず、「Built-in Algorithms with pre-trained Model Table
」で、選択したモデルのモデル ID を調べます。 -
そのモデル ID を使用して、トレーニングジョブを JumpStart 推定器として定義します。
from sagemaker.jumpstart.estimator import JumpStartEstimator model_id =
"huggingface-textgeneration1-gpt-j-6b"
estimator = JumpStartEstimator(model_id=model_id) -
モデルで
estimator.fit()
を実行し、ファインチューニングに使用するトレーニングデータを指定します。estimator.fit( {"train":
training_dataset_s3_path
, "validation":validation_dataset_s3_path
} ) -
次に、
deploy
メソッドを使用して、推論用にモデルを自動的にデプロイします。この例では、Hugging Face の GPT-J 6B モデルを使用します。predictor = estimator.deploy()
-
その後、
predict
メソッドを使用して、デプロイしたモデルで推論を実行できます。question =
"What is Southern California often abbreviated as?"
response = predictor.predict(question) print(response)
注記
この例では、基盤モデル GPT-J 6B を使用しています。このモデルは、質疑応答、固有表現抽出、要約など、幅広いテキスト生成のユースケースに適しています。モデルのユースケースの詳細については、「利用可能な基盤モデル」を参照してください。
JumpStartEstimator
の作成時に、モデルのバージョンやインスタンスのタイプをオプションで指定できます。JumpStartEstimator
クラスとそのパラメータの詳細については、「JumpStartEstimator
デフォルトのインスタンスタイプを確認する
事前トレーニング済みのモデルを JumpStartEstimator
クラスを使用してファインチューニングするときに、モデルのバージョンやインスタンスのタイプをオプションで指定することができます。すべての JumpStart モデルにはデフォルトのインスタンスタイプがあります。次のコードを使用して、デフォルトのトレーニングインスタンスタイプを取得してください。
from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope=
"training"
) print(instance_type)
instance_types.retrieve()
メソッドを使用して、特定の JumpStart モデルでサポートされているすべてのインスタンスタイプを確認できます。
デフォルトのハイパーパラメータを確認する
トレーニングに使用されるデフォルトのハイパーパラメータを確認するには、hyperparameters
クラスから retrieve_default()
メソッドを使用できます。
from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)
利用可能なハイパーパラメータの詳細については、「一般的にサポートされているファインチューニングのハイパーパラメータ」を参照してください。
デフォルトのメトリクス定義を確認する
デフォルトのメトリクス定義を確認することもできます。
print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))