TabTransformer hiperparameter - Amazon SageMaker AI

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

TabTransformer hiperparameter

Tabel berikut berisi subset hiperparameter yang diperlukan atau paling umum digunakan untuk algoritma Amazon SageMaker AI TabTransformer . Pengguna mengatur parameter ini untuk memfasilitasi estimasi parameter model dari data. TabTransformerAlgoritma SageMaker AI adalah implementasi dari TabTransformerpaket open-source.

catatan

Hyperparameter default didasarkan pada contoh kumpulan data di file. TabTransformer contoh buku catatan

TabTransformer Algoritma SageMaker AI secara otomatis memilih metrik evaluasi dan fungsi objektif berdasarkan jenis masalah klasifikasi. TabTransformer Algoritma mendeteksi jenis masalah klasifikasi berdasarkan jumlah label dalam data Anda. Untuk masalah regresi, metrik evaluasi adalah r kuadrat dan fungsi tujuannya adalah kesalahan kuadrat rata-rata. Untuk masalah klasifikasi biner, metrik evaluasi dan fungsi objektif keduanya adalah entropi silang biner. Untuk masalah klasifikasi multikelas, metrik evaluasi dan fungsi objektif keduanya adalah entropi silang multikelas.

catatan

Metrik TabTransformer evaluasi dan fungsi objektif saat ini tidak tersedia sebagai hiperparameter. Alih-alih, algoritme TabTransformer bawaan SageMaker AI secara otomatis mendeteksi jenis tugas klasifikasi (regresi, biner, atau multiclass) berdasarkan jumlah bilangan bulat unik di kolom label dan menetapkan metrik evaluasi dan fungsi objektif.

Nama Parameter Deskripsi
n_epochs

Jumlah zaman untuk melatih jaringan saraf dalam.

Nilai yang valid: bilangan bulat, rentang: Bilangan bulat positif.

Nilai default:5.

patience

Pelatihan akan berhenti jika satu metrik dari satu titik data validasi tidak membaik di patience babak terakhir.

Nilai yang valid: integer, range: (2,60).

Nilai default:10.

learning_rate

Tingkat di mana bobot model diperbarui setelah mengerjakan setiap batch contoh pelatihan.

Nilai yang valid: float, range: Nomor floating point positif.

Nilai default:0.001.

batch_size

Jumlah contoh disebarkan melalui jaringan.

Nilai yang valid: integer, range: (1,2048).

Nilai default:256.

input_dim

Dimensi penyematan untuk menyandikan kolom kategoris dan/atau kontinu.

Nilai yang valid: string, salah satu dari berikut ini: "16""32","64",,"128","256", atau"512".

Nilai default:"32".

n_blocks

Jumlah blok encoder Transformer.

Nilai yang valid: integer, range: (1,12).

Nilai default:4.

attn_dropout

Tingkat putus sekolah diterapkan pada lapisan Multi-Head Attention.

Nilai yang valid: float, range: (0,1).

Nilai default:0.2.

mlp_dropout

Tingkat putus sekolah diterapkan ke FeedForward jaringan dalam lapisan encoder dan lapisan MLP akhir di atas encoder Transformer.

Nilai yang valid: float, range: (0,1).

Nilai default:0.1.

frac_shared_embed

Fraksi embeddings dibagi oleh semua kategori yang berbeda untuk satu kolom tertentu.

Nilai yang valid: float, range: (0,1).

Nilai default:0.25.