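The AWS Deep Learning AMIs are ready to use with Arm64 processor-based GPUs and come optimized for PyTorch. The ARM64 GPU PyTorch DLAMI includes a Python environment pre-configured with PyTorch, TorchVision, and TorchServe.
Verify PyTorch Python Environment
Connect to your G5g instance and activate the base Conda environment using the following command: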
source activate base
Your command prompt should indicate that you are working in the base Conda environment, which contains PyTorch, TorchVision, and other libraries.
(base) $
Verify the default tool paths of the PyTorch environment:
(base) $ which python
(base) $ which pip
(base) $ which conda
(base) $ which mamba
(base) $ python
>>> import torch, torchvision
>>> torch.__version__
>>> torchvision.__version__
>>> v = torch.randn(10, 3, 224, 224)
>>> v = torch.randn(10, 3, 224, 224).cuda()
>>> assert isinstance(v, torch.Tensor)
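The checks above can also be collected into one script. The following is a minimal sketch, not part of the DLAMI itself: the helper names `tool_paths` and `torch_report` are hypothetical, and the script only reports what it finds instead of asserting that a GPU is present.

```python
import importlib
import shutil


def tool_paths(names=("python", "pip", "conda", "mamba")):
    """Map each tool name to its resolved path on PATH, or None if absent."""
    return {name: shutil.which(name) for name in names}


def torch_report():
    """Return version and CUDA info, or None if PyTorch is not importable."""
    try:
        torch = importlib.import_module("torch")
        torchvision = importlib.import_module("torchvision")
    except ImportError:
        return None
    return {
        "torch": torch.__version__,
        "torchvision": torchvision.__version__,
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for name, path in tool_paths().items():
        print(f"{name}: {path or 'not found'}")
    print(torch_report() or "PyTorch is not importable in this environment")
```

With the base environment active, every tool path would be expected to point inside the Conda installation, and `cuda_available` would be expected to be `True` on a GPU instance.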
Run Training Sample with PyTorch
Run a sample MNIST training job:
git clone https://github.com/pytorch/examples.git
cd examples/mnist
python main.py
Your output should look similar to the following:
...
Train Epoch: 14 [56320/60000 (94%)] Loss: 0.021424
Train Epoch: 14 [56960/60000 (95%)] Loss: 0.023695
Train Epoch: 14 [57600/60000 (96%)] Loss: 0.001973
Train Epoch: 14 [58240/60000 (97%)] Loss: 0.007121
Train Epoch: 14 [58880/60000 (98%)] Loss: 0.003717
Train Epoch: 14 [59520/60000 (99%)] Loss: 0.001729
Test set: Average loss: 0.0275, Accuracy: 9916/10000 (99%)
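To check a run programmatically rather than by eye, the log can be parsed. This is a sketch that assumes the log lines keep exactly the shape shown above; the function name `parse_mnist_log` and the regular expressions are illustrative and are not part of the example repository.

```python
import re

# Patterns written against the sample output above (an assumption about format).
TRAIN_RE = re.compile(
    r"Train Epoch: (\d+) \[(\d+)/(\d+) \((\d+)%\)\]\s+Loss: ([\d.]+)"
)
TEST_RE = re.compile(
    r"Test set: Average loss: ([\d.]+), Accuracy: (\d+)/(\d+)"
)


def parse_mnist_log(text):
    """Collect (epoch, samples_seen, loss) tuples and the last test summary."""
    losses = []
    summary = None
    for line in text.splitlines():
        m = TRAIN_RE.search(line)
        if m:
            losses.append((int(m.group(1)), int(m.group(2)), float(m.group(5))))
            continue
        m = TEST_RE.search(line)
        if m:
            correct, total = int(m.group(2)), int(m.group(3))
            summary = {
                "avg_loss": float(m.group(1)),
                "correct": correct,
                "total": total,
                "accuracy": correct / total,
            }
    return losses, summary


sample = """\
Train Epoch: 14 [58880/60000 (98%)] Loss: 0.003717
Train Epoch: 14 [59520/60000 (99%)] Loss: 0.001729
Test set: Average loss: 0.0275, Accuracy: 9916/10000 (99%)
"""
losses, summary = parse_mnist_log(sample)
print(losses[-1])
print(summary["accuracy"])
```

The same function could be applied to a saved log, for example one captured with `python main.py | tee train.log`.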
Run Inference Sample with PyTorch
Use the following commands to download a pre-trained densenet161 model and run inference using TorchServe:
# Set up TorchServe
cd $HOME
git clone https://github.com/pytorch/serve.git
mkdir -p serve/model_store
cd serve
# Download a pre-trained densenet161 model
wget -q https://download.pytorch.org/models/densenet161-8d451a50.pth
# Save the model using torch-model-archiver
torch-model-archiver --model-name densenet161 \
--version 1.0 \
--model-file examples/image_classifier/densenet_161/model.py \
--serialized-file densenet161-8d451a50.pth \
--handler image_classifier \
--extra-files examples/image_classifier/index_to_name.json \
--export-path model_store
# Start the model server
torchserve --start --no-config-snapshots \
--model-store model_store \
--models densenet161=densenet161.mar &> torchserve.log
# Wait for the model server to start
sleep 30
# Run a prediction request
curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg
Your output should look similar to the following:
{
"tiger_cat": 0.4693363308906555,
"tabby": 0.4633873701095581,
"Egyptian_cat": 0.06456123292446136,
"lynx": 0.0012828150065615773,
"plastic_bag": 0.00023322898778133094
}
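The fixed `sleep 30` and the `curl` request can also be expressed as a small Python client. This is a sketch under stated assumptions: it polls the TorchServe `/ping` health endpoint on the inference port instead of sleeping, it treats an HTTP 200 reply as "ready", and the helper names (`wait_until_healthy`, `predict`, `top_label`) are hypothetical. The request uses PUT because that is what `curl -T` sends.

```python
import json
import time
import urllib.error
import urllib.request

INFERENCE_URL = "http://127.0.0.1:8080"  # inference address used above


def wait_until_healthy(base_url=INFERENCE_URL, timeout=60.0, interval=2.0):
    """Poll the /ping endpoint until it answers HTTP 200 or time runs out."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(f"{base_url}/ping", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not reachable or not healthy yet
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


def predict(image_path, model="densenet161", base_url=INFERENCE_URL):
    """Send the image bytes to the prediction endpoint and decode the JSON."""
    with open(image_path, "rb") as f:
        body = f.read()
    req = urllib.request.Request(
        f"{base_url}/predictions/{model}", data=body, method="PUT"
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.load(resp)


def top_label(scores):
    """Return the (label, probability) pair with the highest probability."""
    return max(scores.items(), key=lambda item: item[1])


# Applied to the sample response shown above.
sample_scores = {
    "tiger_cat": 0.4693363308906555,
    "tabby": 0.4633873701095581,
    "Egyptian_cat": 0.06456123292446136,
}
print(top_label(sample_scores))
```

With the server running, calling `wait_until_healthy()` and then `predict("examples/image_classifier/kitten.jpg")` from the `serve` directory would be expected to return a dictionary like the one shown above.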
Use the following commands to unregister the densenet161 model and stop the server:
curl -X DELETE http://localhost:8081/models/densenet161/1.0
torchserve --stop
Your output should look similar to the following:
{
"status": "Model \"densenet161\" unregistered"
}
TorchServe has stopped.
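The same teardown can be scripted in Python. This is a minimal sketch, assuming the management API listens on port 8081 as in the command above and that `torchserve` is on the PATH; the function names are hypothetical.

```python
import subprocess
import urllib.request

MANAGEMENT_URL = "http://localhost:8081"  # management address used above


def build_unregister_request(model, version, base_url=MANAGEMENT_URL):
    """Build the DELETE request that unregisters one model version."""
    return urllib.request.Request(
        f"{base_url}/models/{model}/{version}", method="DELETE"
    )


def unregister_and_stop(model="densenet161", version="1.0"):
    """Unregister the model, print the reply, then stop the server."""
    req = build_unregister_request(model, version)
    with urllib.request.urlopen(req, timeout=30) as resp:
        print(resp.read().decode("utf-8"))
    # Same effect as running `torchserve --stop` in the shell.
    subprocess.run(["torchserve", "--stop"], check=True)
```

Keeping the request construction in its own function makes it possible to inspect the URL and HTTP method without contacting a running server.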