Using DataLoader with num_workers greater than 0 can cause increased memory consumption over time when iterating over native Python objects such as list or dict. PyTorch uses multiprocessing in this scenario and places the data in shared memory; however, reference counting triggers copy-on-write, which increases memory consumption over time. This behavior resembles a memory leak. Using pandas, numpy, or pyarrow arrays instead solves this problem.
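Before the detector's noncompliant and compliant examples below, the following is a minimal, self-contained sketch of the pattern in question; the dataset classes, sizes, and worker count are illustrative assumptions rather than part of the detector's examples.

import numpy as np
from torch.utils.data import Dataset, DataLoader

class ListBackedDataset(Dataset):
    # Problematic pattern: one Python object per sample in a native list.
    # With num_workers > 0, each fork-started worker touches the elements'
    # refcounts on access, dirtying copy-on-write pages so memory grows.
    def __init__(self):
        self.samples = [float(i) for i in range(1_000_000)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

class ArrayBackedDataset(Dataset):
    # Recommended pattern: samples live in a single numpy buffer, so worker
    # reads do not update per-element refcounts and shared pages stay shared.
    def __init__(self):
        self.samples = np.arange(1_000_000, dtype=np.float32)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

if __name__ == "__main__":
    loader = DataLoader(ArrayBackedDataset(), batch_size=256, num_workers=4)
    for batch in loader:
        pass  # swap in ListBackedDataset() to observe the growing memory footprint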
def pytorch_data_loader_with_multiple_workers_noncompliant():
    import torch
    from torch.utils.data import DataLoader
    import numpy as np
    sampler = InfomaxNodeRecNeighborSampler(g, [fanout] * (n_layers),
                                            device=device, full_neighbor=True)
    pr_node_ids = list(sampler.hetero_map.keys())
    pr_val_ind = list(np.random.choice(len(pr_node_ids),
                                       int(len(pr_node_ids) * val_pct),
                                       replace=False))
    pr_train_ind = list(set(list(np.arange(len(pr_node_ids))))
                        .difference(set(pr_val_ind)))

    # Noncompliant: num_workers value is 8 and native python 'list'
    # is used here to store the dataset.
    loader = DataLoader(dataset=pr_train_ind,
                        batch_size=batch_size,
                        collate_fn=sampler.sample_blocks,
                        shuffle=True,
                        num_workers=8)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=l2norm)

    # training loop
    print("start training...")

    for epoch in range(n_epochs):
        model.train()
def pytorch_data_loader_with_multiple_workers_compliant(args):
    import os
    import torch.optim
    import torch.utils.data
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # imagenet_transforms is not defined in the original snippet; a typical
    # ImageNet training pipeline built around 'normalize' is assumed here.
    imagenet_transforms = transforms.Compose([transforms.RandomResizedCrop(224),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              normalize])

    train_dataset = datasets.ImageFolder(traindir, imagenet_transforms)
    train_sampler = torch.utils.data.distributed\
        .DistributedSampler(train_dataset)

    # Compliant: args.workers value is assigned to num_workers,
    # but native python 'list/dict' is not used here to store the dataset.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
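As a closing sketch, the indices-as-dataset case from the noncompliant example could also keep num_workers=8 if the index list is handed to DataLoader as a numpy array; the node count, validation fraction, and batch size below are stand-in values, and the custom sampler and collate_fn from the original snippet are omitted.

import numpy as np
from torch.utils.data import DataLoader

# Stand-in values; in the noncompliant snippet these come from the graph sampler.
num_nodes = 10_000
val_pct = 0.1
batch_size = 64

all_ind = np.arange(num_nodes)
pr_val_ind = np.random.choice(all_ind, int(num_nodes * val_pct), replace=False)
# numpy arrays instead of list(set(...).difference(...)): the index data lives
# in contiguous buffers, so workers do not dirty copy-on-write pages through
# per-element refcount updates.
pr_train_ind = np.setdiff1d(all_ind, pr_val_ind)

loader = DataLoader(dataset=pr_train_ind,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=8)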