TabTransformer hiperparâmetros - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

TabTransformer hiperparâmetros

A tabela a seguir contém o subconjunto de hiperparâmetros que são necessários ou mais comumente usados para o algoritmo da Amazon SageMaker TabTransformer . Os usuários definem esses parâmetros para facilitar a estimativa dos parâmetros do modelo a partir dos dados. O SageMaker TabTransformer algoritmo é uma implementação do TabTransformerpacote de código aberto.

nota

Os hiperparâmetros padrão são baseados em conjuntos de dados de exemplo no TabTransformer cadernos de amostra.

O SageMaker TabTransformer algoritmo escolhe automaticamente uma métrica de avaliação e uma função objetiva com base no tipo de problema de classificação. O TabTransformer algoritmo detecta o tipo de problema de classificação com base no número de rótulos em seus dados. Para problemas de regressão, a métrica de avaliação é o r quadrático e a função objetivo é o erro quadrático médio. Para problemas de classificação binária, a métrica de avaliação e a função objetiva são ambas entropia cruzada binária. Para problemas de classificação multiclasse, a métrica de avaliação e a função objetiva são ambas entropia cruzada multiclasse.

nota

A métrica de TabTransformer avaliação e as funções objetivas não estão atualmente disponíveis como hiperparâmetros. Em vez disso, o algoritmo SageMaker TabTransformer integrado detecta automaticamente o tipo de tarefa de classificação (regressão, binária ou multiclasse) com base no número de números inteiros exclusivos na coluna do rótulo e atribui uma métrica de avaliação e uma função objetiva.

Nome do parâmetro Descrição
n_epochs

Número de épocas para treinar a rede neural profunda.

Valores válidos: inteiro, intervalo: inteiro positivo.

Valor padrão: 5.

patience

O treinamento será interrompido se uma métrica de um ponto de dados de validação não melhorar na última rodada patience.

Valores válidos: flutuante, intervalo: (2, 60).

Valor padrão: 10.

learning_rate

A taxa na qual os pesos do modelo são atualizados depois de analisar cada lote de exemplos de treinamento.

Valores válidos: flutuante, intervalo: número de ponto flutuante positivo.

Valor padrão: 0.001.

batch_size

O número de exemplos propagados pela rede.

Valores válidos: flutuante, intervalo: (1, 2048).

Valor padrão: 256.

input_dim

A dimensão das incorporações para codificar as colunas categóricas e/ou contínuas.

Valores válidos: string, qualquer um dos seguintes: "16", "32", "64", "128", "256" ou "512".

Valor padrão: "32".

n_blocks

O número de blocos do codificador Transformer.

Valores válidos: flutuante, intervalo: (1, 12).

Valor padrão: 4.

attn_dropout

Taxa de desistência aplicada às camadas Multi-Head Attention.

Valores válidos: flutuante. Intervalo: (0, 1).

Valor padrão: 0.2.

mlp_dropout

Taxa de abandono aplicada à FeedForward rede dentro das camadas do codificador e às MLP camadas finais na parte superior dos codificadores do Transformer.

Valores válidos: flutuante. Intervalo: (0, 1).

Valor padrão: 0.1.

frac_shared_embed

A fração de incorporações compartilhadas por todas as diferentes categorias de uma coluna específica.

Valores válidos: flutuante, intervalo: (0, 1).

Valor padrão: 0.25.