使用 JumpStartEstimator
类微调公开可用的基础模型
使用 SageMaker Python SDK,只需几行代码就能对内置算法或预训练模型进行微调。
-
首先,在内置算法与预训练模型表
中找到所选模型的模型 ID。 -
使用模型 ID 将训练作业定义为 JumpStart 估算器。
from sagemaker.jumpstart.estimator import JumpStartEstimator model_id =
"huggingface-textgeneration1-gpt-j-6b"
estimator = JumpStartEstimator(model_id=model_id) -
在模型上运行
estimator.fit()
,指向用于微调的训练数据。estimator.fit( {"train":
training_dataset_s3_path
, "validation":validation_dataset_s3_path
} ) -
然后,使用
deploy
方法自动部署模型进行推理。在此示例中,我们使用 Hugging Face 的 GPT-J 6B 模型。predictor = estimator.deploy()
-
然后,您就可以使用
predict
方法对已部署的模型进行推理。question =
"What is Southern California often abbreviated as?"
response = predictor.predict(question) print(response)
注意
此示例使用基础模型 GPT-J 6B,该模型适用于各种文本生成使用场景,包括问题解答、命名实体识别、摘要等。有关模型使用场景的更多信息,请参阅 可用的基础模型。
创建 JumpStartEstimator
时,您可以选择指定模型版本或实例类型。有关 JumpStartEstimator
类及其参数的更多信息,请参阅 JumpStartEstimator
检查默认实例类型
在使用 JumpStartEstimator
类对预训练模型进行微调时,您可以选择包含特定的模型版本或实例类型。所有 JumpStart 模型都有默认实例类型。使用以下代码读取默认训练实例类型:
from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope=
"training"
) print(instance_type)
您可以使用 instance_types.retrieve()
方法查看特定 JumpStart 模型支持的所有实例类型。
检查默认超参数
要检查用于训练的默认超参数,可以使用 hyperparameters
类中的 retrieve_default()
方法。
from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)
有关可用超参数的更多信息,请参阅 通常支持的微调超参数。
检查默认指标定义
您还可以检查默认指标定义:
print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))