Cached Dataset

When samples in a dataset require expensive preprocessing, or reading the dataset itself is slow, data loading can become a major bottleneck of the whole training process. Bagua provides a cached dataset to speed up this process by caching data samples in memory, so that reading these samples after the first time is much faster.

Usage

CachedDataset is a PyTorch custom dataset (see Creating a Custom Dataset for your files). It wraps a PyTorch dataset and caches its samples in a distributed key-value store. We can specify the backend to use when initializing a cached dataset. Currently Redis, an in-memory data store, is supported (see the note on installation below).

By default, the cached dataset spawns a new Redis instance on each worker node, and data is sharded across all Redis instances on all nodes in the Bagua job. We can specify the maximum memory limit to use on each node by passing capacity_per_node to CachedDataset.

The following example uses a Redis-backed cached dataset with a maximum memory limit of 400GB on each node; a 4-node Bagua job therefore has a total memory limit of 1.6TB.

import torch
from bagua.torch_api.contrib import CachedDataset

cache_dataset = CachedDataset(
    dataset,
    backend="redis",
    dataset_name="ds",
    capacity_per_node=400 * 1024 * 1024 * 1024,
)
dataloader = torch.utils.data.DataLoader(cache_dataset)

for i, (input, target) in enumerate(dataloader):
    ...

By setting cluster_mode=False, we can restrict each training node to use only its local Redis instance.

cache_dataset = CachedDataset(
    dataset,
    backend="redis",
    dataset_name="ds",
    cluster_mode=False,
    capacity_per_node=400 * 1024 * 1024 * 1024,
)

We can also use existing Redis servers as the backend store by passing a list of Redis server addresses to hosts.

hosts = [
    {"host": "192.168.1.0", "port": "7000"},
    {"host": "192.168.1.1", "port": "7000"},
]
cache_dataset = CachedDataset(
    dataset,
    backend="redis",
    dataset_name="ds",
    hosts=hosts,
)
Note: To try the cached dataset out, users need to install redis-py and redis-server first, or simply use the Docker images we provide.

Multiple cached datasets

Multiple cached datasets share the same backend store, so we need to specify a unique dataset_name for each dataset to avoid one dataset overwriting another's samples.

from bagua.torch_api.contrib import CachedDataset

dataset1 = ...
dataset2 = ...

cache_dataset1 = CachedDataset(
    dataset1,
    backend="redis",
    dataset_name="ds1",
    capacity_per_node=400 * 1024 * 1024 * 1024,
)

cache_dataset2 = CachedDataset(
    dataset2,
    backend="redis",
    dataset_name="ds2",
    capacity_per_node=400 * 1024 * 1024 * 1024,
)

Note that the Redis instance is only spawned once on each node, and other cached datasets reuse the existing instance. Only the parameters used to spawn the first Redis instance take effect (see the note below). In the example above, the maximum memory limit on each node remains 400GB even if we set capacity_per_node to a different number when initializing cache_dataset2.

Note: cluster_mode and capacity_per_node are used to spawn new Redis instances when hosts=None. See RedisStore for more information.

Dataset with augmentation

For a dataset with augmentation, we cannot use the cached dataset directly, because caching the final output of __getitem__ would also freeze the random augmentation; only the deterministic preprocessing should be cached. Instead, we can define our own custom dataset using CacheLoader. Here is an example.

import torch.utils.data as data
from bagua.torch_api.contrib import CacheLoader


class PanoHand(data.Dataset):
    def __init__(self):
        super(PanoHand, self).__init__()

        self.img_list = ...
        self.cache_loader = CacheLoader(
            backend="redis",
            capacity_per_node=400 * 1024 * 1024 * 1024,
            hosts=None,
            cluster_mode=True,
        )

    def __getitem__(self, idx):
        return self.get_training_sample(idx)

    def _process_fn(self, idx):
        # deterministic preprocessing; its result is what gets cached
        ...

    def get_training_sample(self, idx):
        # fetch the preprocessed sample from the cache; on the first access it
        # is computed by _process_fn and written to the backend store
        ret = self.cache_loader.get(idx, self._process_fn)

        # apply random data augmentation to ret and return the result
        ...

    def __len__(self):
        return len(self.img_list)
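
As a usage sketch, the dataset above can be fed to a regular PyTorch DataLoader; the first epoch fills the cache with preprocessed samples, while later epochs only rerun the augmentation step. The batch size and epoch count below are arbitrary choices for illustration.

import torch

dataset = PanoHand()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for epoch in range(3):  # the first epoch fills the cache, later epochs read from it
    for i, sample in enumerate(dataloader):
        ...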

Note: CachedDataset is built upon CacheLoader as well.
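
To make that relationship concrete, here is a minimal sketch (not Bagua's actual implementation) of how a dataset wrapper could be built on top of CacheLoader, assuming only the CacheLoader parameters and the get method shown on this page; the class name MyCachedDataset is hypothetical.

import torch.utils.data as data
from bagua.torch_api.contrib import CacheLoader


class MyCachedDataset(data.Dataset):
    # Hypothetical wrapper, for illustration only.
    def __init__(self, dataset, **cache_loader_kwargs):
        self.dataset = dataset
        # cache_loader_kwargs is assumed to carry the parameters used above,
        # e.g. backend, capacity_per_node, hosts, cluster_mode.
        self.cache_loader = CacheLoader(**cache_loader_kwargs)

    def __getitem__(self, idx):
        # The first access reads dataset[idx] and writes it to the store;
        # later accesses return the cached sample instead.
        return self.cache_loader.get(idx, lambda i: self.dataset[i])

    def __len__(self):
        return len(self.dataset)

The real CachedDataset additionally takes a dataset_name, which keeps the keys of different datasets separate, as described earlier.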

Benchmark result

On an important internal 3D mesh training task, where data preprocessing is the major bottleneck, using the cached loader on one NVIDIA Tesla V100 GPU reduces the end-to-end training time by more than 60%, incurring only a small overhead for writing to the key-value store in the first epoch.

|                   | w/o Cached Loader | w/ Cached Loader |
|-------------------|-------------------|------------------|
| Epoch #1 Time (s) | 6375              | 6473             |
| Epoch #2 Time (s) | 6306              | 2264             |
| Epoch #3 Time (s) | 6321              | 2240             |