本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
演算法的 CatBoost輸入和輸出介面
梯度提升在表格式資料中操作,含有代表觀察的行、還有一個代表目標變數或標籤的欄,而剩下的欄則代表功能。
訓練和推論 CatBoost CSV支援的 SageMaker 實作:
-
對於訓練 ContentType,有效輸入必須是文字/csv。
-
對於推論 ContentType,有效的輸入必須是文字/csv。
注意
針對CSV訓練,演算法會假設目標變數位於第一欄,且 CSV沒有標頭記錄。
對於CSV推論,演算法會假設CSV輸入沒有標籤欄。
訓練資料、驗證資料和分類功能的輸入格式
請注意如何格式化訓練資料以輸入 CatBoost 模型。您必須提供包含訓練和驗證資料之 Amazon S3 儲存貯體的路徑。您也可以內涵分類功能清單。同時使用training
和validation
通道來提供您的輸入資料。或者,您可以只使用training
頻道。
同時使用training
和validation
通道
您可以透過兩個 S3 路徑提供輸入資料,一個用於training
通道,另一個用於validation
通道。每個 S3 路徑可以是指向一或多個CSV檔案的 S3 字首,也可以是指向一個特定CSV檔案的完整 S3 路徑。目標變數應位於CSV檔案的第一欄。預測變量 (功能) 應該位於其餘列中。如果為 training
或 validation
通道提供多個CSV檔案, CatBoost 則演算法會串連檔案。驗證資料用於計算每次增加迭代結束時的驗證分數。當驗證分數停止改善時,會套用提前停止。
如果您的預測器包含類別功能,您可以提供categorical_index.json
名為 JSON 的檔案,其位置與訓練資料檔案相同。如果您為類別功能提供JSON檔案,您的training
頻道必須指向 S3 字首,而不是特定CSV檔案。這個文件應該包含一個 Python 字典,其中索引鍵是字串 "cat_index_list"
,該值是唯一整數的清單。值清單中的每個整數應指出訓練資料CSV檔案中對應類別特徵的資料欄索引。每個值都應該是一個正整數 (大於零,因為零表示目標值)、小於 Int32.MaxValue
(2147483647),且小於資料欄的總數。應該只有一個類別索引JSON檔案。
僅使用training
通道:
或者,您也可以透過training
通道的單一 S3 路徑提供輸入資料。此 S3 路徑應指向具有名為 的子目錄training/
的目錄,該子目錄包含一或多個CSV檔案。您可以選擇性地在名為 的相同位置包含另一個子目錄validation/
,該位置也具有一或多個CSV檔案。如果未提供驗證資料,則會隨機抽樣 20% 的訓練資料,做為驗證資料。如果您的預測器包含類別功能,您可以提供名為 JSON的檔案,其位置categorical_index.json
與資料子目錄相同。
注意
對於CSV訓練輸入模式,演算法可用的總記憶體 (執行個體計數乘以 中可用的記憶體InstanceType
) 必須能夠保留訓練資料集。
SageMaker CatBoost 使用 catboost.CatBoostClassifier
和 catboost.CatBoostRegressor
模組來序列化或還原序列化模型,可用於儲存或載入模型。
若要使用使用 訓練 SageMaker CatBoost 的模型 catboost
-
使用以下 Python 程式碼:
import tarfile from catboost import CatBoostClassifier t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() file_path = os.path.join(model_file_path, "model") model = CatBoostClassifier() model.load_model(file_path) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(
dtest
)