使用 JumpStartEstimator 类微调公开可用的基础模型 - Amazon SageMaker

使用 JumpStartEstimator 类微调公开可用的基础模型

使用 SageMaker Python SDK,只需几行代码就能对内置算法或预训练模型进行微调。

  1. 首先,在内置算法与预训练模型表中找到所选模型的模型 ID。

  2. 使用模型 ID 将训练作业定义为 JumpStart 估算器。

    from sagemaker.jumpstart.estimator import JumpStartEstimator model_id = "huggingface-textgeneration1-gpt-j-6b" estimator = JumpStartEstimator(model_id=model_id)
  3. 在模型上运行 estimator.fit(),指向用于微调的训练数据。

    estimator.fit( {"train": training_dataset_s3_path, "validation": validation_dataset_s3_path} )
  4. 然后,使用 deploy 方法自动部署模型进行推理。在此示例中,我们使用 Hugging Face 的 GPT-J 6B 模型。

    predictor = estimator.deploy()
  5. 然后,您就可以使用 predict 方法对已部署的模型进行推理。

    question = "What is Southern California often abbreviated as?" response = predictor.predict(question) print(response)
注意

此示例使用基础模型 GPT-J 6B,该模型适用于各种文本生成使用场景,包括问题解答、命名实体识别、摘要等。有关模型使用场景的更多信息,请参阅 可用的基础模型

创建 JumpStartEstimator 时,您可以选择指定模型版本或实例类型。有关 JumpStartEstimator 类及其参数的更多信息,请参阅 JumpStartEstimator

检查默认实例类型

在使用 JumpStartEstimator 类对预训练模型进行微调时,您可以选择包含特定的模型版本或实例类型。所有 JumpStart 模型都有默认实例类型。使用以下代码读取默认训练实例类型:

from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope="training") print(instance_type)

您可以使用 instance_types.retrieve() 方法查看特定 JumpStart 模型支持的所有实例类型。

检查默认超参数

要检查用于训练的默认超参数,可以使用 hyperparameters 类中的 retrieve_default() 方法。

from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)

有关可用超参数的更多信息,请参阅 通常支持的微调超参数

检查默认指标定义

您还可以检查默认指标定义:

print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))