用于微调的基础模型和超参数 - Amazon SageMaker

用于微调的基础模型和超参数

基础模型的计算成本很高,并且是在大型的、未标注的语料库上训练的。微调预训练的基础模型是一种经济实惠的方式,可以利用其广泛的功能,同时根据自己的小型语料库来自定义模型。微调是一种自定义方法,它涉及进一步的训练,并且会改变模型的权重。

如果您有以下要求,微调可能会很有用:

  • 根据特定业务需求自定义模型

  • 让模型可以成功处理特定于领域的语言,例如行业术语、技术术语或其他专业词汇

  • 针对特定任务增强性能

  • 在应用中提供准确、相对的和感知上下文的响应

  • 更真实、毒性更小、更符合具体要求的响应

根据使用案例和所选的基础模型,您可以采用两种主要方法进行微调。

  1. 如果您有兴趣根据特定于领域数据微调模型,请参阅利用领域适应性微调大型语言模型(LLM)

  2. 如果您对使用提示和响应样本进行基于指令的微调感兴趣,请参阅使用提示指令微调大型语言模型(LLM)

可进行微调的基础模型

您可以对以下任何一种 JumpStart 基础模型进行微调:

  • Bloom 3B

  • Bloom 7B1

  • BloomZ 3B FP16

  • BloomZ 7B1 FP16

  • Code Llama 13B

  • Code Llama 13B Python

  • Code Llama 34B

  • Code Llama 34B Python

  • Code Llama 70B

  • Code Llama 70B Python

  • Code Llama 7B

  • Code Llama 7B Python

  • CyberAgentLM2-7B-Chat (CALM2-7B-Chat)

  • Falcon 40B BF16

  • Falcon 40B Instruct BF16

  • Falcon 7B BF16

  • Falcon 7B Instruct BF16

  • Flan-T5 Base

  • Flan-T5 Large

  • Flan-T5 Small

  • Flan-T5 XL

  • Flan-T5 XXL

  • Gemma 2B

  • Gemma 2B Instruct

  • Gemma 7B

  • Gemma 7B Instruct

  • GPT-2 XL

  • GPT-J 6B

  • GPT-Neo 1.3B

  • GPT-Neo 125M

  • GPT-NEO 2.7B

  • LightGPT Instruct 6B

  • Llama 2 13B

  • Llama 2 13B Chat

  • Llama 2 13B Neuron

  • Llama 2 70B

  • Llama 2 70B Chat

  • Llama 2 7B

  • Llama 2 7B Chat

  • Llama 2 7B Neuron

  • Mistral 7B

  • Mixtral 8x7B

  • Mixtral 8x7B Instruct

  • RedPajama INCITE Base 3B V1

  • RedPajama INCITE Base 7B V1

  • RedPajama INCITE Chat 3B V1

  • RedPajama INCITE Chat 7B V1

  • RedPajama INCITE Instruct 3B V1

  • RedPajama INCITE Instruct 7B V1

  • Stable Diffusion 2.1

通常支持的微调超参数

微调时,不同的基础模型支持不同的超参数。以下是常用的超参数,可在训练过程中进一步自定义模型:

推理参数 描述

epoch

模型在训练过程中通过微调数据集的次数。必须是大于 1 的整数。

learning_rate

完成每批微调训练样本后,更新模型权重的速度。必须是大于 0 的正浮点数。

instruction_tuned

是否对模型进行指令训练。必须为 'True''False'

per_device_train_batch_size

用于训练的每个 GPU 内核或 CPU 的批量大小。其值必须为正整数。

per_device_eval_batch_size

用于评估的每个 GPU 内核或 CPU 的批量大小。其值必须为正整数。

max_train_samples

为了调试或加快训练速度,请将训练样本的数量截断为该值。值 -1 表示模型使用了所有训练样本。必须是正整数或 -1。

max_val_samples

为了调试或加快训练速度,请将验证样本的数量截断为该值。值 -1 表示模型使用了所有验证样本。必须是正整数或 -1。

max_input_length

令牌化后输入序列的最大总长度。长度超过此值的序列将被截断。如果为 -1,max_input_length 将被设置为 1024 和分词器定义的 model_max_length 的最小值。如果设置为正值,max_input_length 将被设置为所提供值和分词器定义的 model_max_length 的最小值。必须是正整数或 -1。

validation_split_ratio

如果没有验证通道,则训练 - 验证的比例将从训练数据中拆分。必须介于 0 和 1 之间。

train_data_split_seed

如果不存在验证数据,则将输入的训练数据随机拆分为模型使用的训练数据和验证数据。必须是整数。

preprocessing_num_workers

用于预处理的进程数。如果 None,则使用主进程进行预处理。

lora_r

低秩适应 (LoRA) r 值,作为权重更新的缩放因子。其值必须为正整数。

lora_alpha

低秩适应 (LoRA) 阿尔法值,作为权重更新的缩放因子。一般是 lora_r 的 2 到 4 倍。其值必须为正整数。

lora_dropout

低秩适应 (LoRA) 层的释放参数必须是介于 0 和 1 之间的正浮点数。

int8_quantization

如果 True,则模型将以 8 位精度加载,以进行训练。

enable_fsdp

如果 True,则训练使用完全分片数据并行。

在 Studio 中微调模型时,您可以指定超参数值。有关更多信息,请参阅 在 Studio 中微调模型

使用 SageMaker Python SDK 微调模型时,您也可以覆盖默认的超参数值。有关更多信息,请参阅 使用 JumpStartEstimator 类微调公开可用的基础模型