公開されている基盤モデルを JumpStartEstimator クラスでファインチューニングする - Amazon SageMaker

公開されている基盤モデルを JumpStartEstimator クラスでファインチューニングする

SageMaker Python SDK を使用して、組み込みアルゴリズムや事前トレーニング済みのモデルをわずか数行のコードでファインチューニングできます。

  1. まず、「Built-in Algorithms with pre-trained Model Table」で、選択したモデルのモデル ID を調べます。

  2. そのモデル ID を使用して、トレーニングジョブを JumpStart 推定器として定義します。

    from sagemaker.jumpstart.estimator import JumpStartEstimator model_id = "huggingface-textgeneration1-gpt-j-6b" estimator = JumpStartEstimator(model_id=model_id)
  3. モデルで estimator.fit() を実行し、ファインチューニングに使用するトレーニングデータを指定します。

    estimator.fit( {"train": training_dataset_s3_path, "validation": validation_dataset_s3_path} )
  4. 次に、deploy メソッドを使用して、推論用にモデルを自動的にデプロイします。この例では、Hugging Face の GPT-J 6B モデルを使用します。

    predictor = estimator.deploy()
  5. その後、predict メソッドを使用して、デプロイしたモデルで推論を実行できます。

    question = "What is Southern California often abbreviated as?" response = predictor.predict(question) print(response)
注記

この例では、基盤モデル GPT-J 6B を使用しています。このモデルは、質疑応答、固有表現抽出、要約など、幅広いテキスト生成のユースケースに適しています。モデルのユースケースの詳細については、「利用可能な基盤モデル」を参照してください。

JumpStartEstimator の作成時に、モデルのバージョンやインスタンスのタイプをオプションで指定できます。JumpStartEstimator クラスとそのパラメータの詳細については、「JumpStartEstimator」を参照してください。

デフォルトのインスタンスタイプを確認する

事前トレーニング済みのモデルを JumpStartEstimator クラスを使用してファインチューニングするときに、モデルのバージョンやインスタンスのタイプをオプションで指定することができます。すべての JumpStart モデルにはデフォルトのインスタンスタイプがあります。次のコードを使用して、デフォルトのトレーニングインスタンスタイプを取得してください。

from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope="training") print(instance_type)

instance_types.retrieve() メソッドを使用して、特定の JumpStart モデルでサポートされているすべてのインスタンスタイプを確認できます。

デフォルトのハイパーパラメータを確認する

トレーニングに使用されるデフォルトのハイパーパラメータを確認するには、hyperparameters クラスから retrieve_default() メソッドを使用できます。

from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)

利用可能なハイパーパラメータの詳細については、「一般的にサポートされているファインチューニングのハイパーパラメータ」を参照してください。

デフォルトのメトリクス定義を確認する

デフォルトのメトリクス定義を確認することもできます。

print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))