Chapter 2.4 Data Loading in PyTorch: Dataset, DataLoader, and Batching

```python
from collections.abc import Iterator

import torch
import torch.utils.data as utils
from torch import Tensor

print('PyTorch version:', torch.__version__)
```
In the previous three sections, we discussed PyTorch’s automatic differentiation mechanism. We know that the forward pass builds a computation graph, and the backward pass sends gradients back along that graph. We also know that during validation and inference, no_grad() or inference_mode() can be used to turn off gradient recording.
However, a complete training process is not only about models and gradients. It must first answer a more basic question: where does the data come from?
In the simplest example, we can write tensors directly:
```python
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
```

Then manually slice them during training:
```python
batch_size = 32
X_batch = X[:batch_size]
y_batch = y[:batch_size]
```

This works, but several problems appear quickly. How should the data be shuffled? How should the last batch be handled if it has fewer than 32 samples? If the data is not a tensor but a set of image files, where should reading and preprocessing happen? If samples have different lengths and cannot be stacked into a regular tensor directly, how should a batch be formed? If reading and preprocessing are too slow, the GPU will wait for the CPU, and throughput will be limited by the data pipeline rather than by the model.
These are not model problems. They are data pipeline problems.
PyTorch organizes this part with two concepts:
- Dataset defines what one sample looks like and how to retrieve it;
- DataLoader organizes many samples into mini-batches and handles common training details such as shuffling, multiprocessing, batch collation, and pinned memory.
This section starts from simple tensor data and gradually explains the design of Dataset and DataLoader.
```python
device = torch.accelerator.current_accelerator(check_available=True)
if device is None:
    device = torch.device('cpu')
print('Using device:', device)
```

2.4.1 Dataset: Access Samples Through a Unified Interface
In a training loop, we usually do not feed all data into the model at once. Instead, we retrieve one mini-batch at a time:
dataset -> mini-batch -> model -> loss -> backward -> optimizer step
If the data already exists as tensors, mini-batches can be implemented directly with slicing. The problem is that sample access, shuffling, batch construction, and handling the last batch all get mixed into the training loop.
Real data is often not one regular large tensor. In an image classification task, for example, the data may be a set of image paths:
cat_001.jpg -> label 0
dog_001.jpg -> label 1
cat_002.jpg -> label 0
...
At this point, a unified interface becomes more useful: whether the data comes from tensors, images, text files, or a database, it can be accessed in the same way:
```python
sample = dataset[index]
```

This is the role of Dataset.
In PyTorch, a basic map-style dataset only needs to answer two questions:
- How many samples are in the dataset?
- Given an index, how do we retrieve the corresponding sample?
In code, this means implementing __len__() and __getitem__().
We first write a minimal version:
```python
class SimpleTensorDataset(utils.Dataset):
    def __init__(self, X: Tensor, y: Tensor):
        if len(X) != len(y):
            raise ValueError('X and y must have the same length.')
        self.X = X
        self.y = y

    def __len__(self) -> int:
        # Number of samples in the dataset
        return len(self.X)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        # Return the (input, target) pair at position `index`
        X = self.X[index]
        y = self.y[index]
        return X, y
```

Now the tensors can be wrapped as a dataset:
```python
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
dataset = SimpleTensorDataset(X, y)

x0, y0 = dataset[0]
print('Dataset length:', len(dataset))
print('First input shape:', x0.shape)
print('First target shape:', y0.shape)
```

With this wrapper, the training code no longer needs to care how the data is stored internally. It only needs to know that samples can be retrieved by index.
PyTorch also provides a built-in version, TensorDataset, which packs any number of tensors into samples along the first dimension:
```python
dataset = utils.TensorDataset(X, y)

x0, y0 = dataset[0]
print('First input shape:', x0.shape)
print('First target shape:', y0.shape)
```

For ordinary tensor data, TensorDataset is sufficient. For more complex data, such as an image classification or text classification dataset, a custom Dataset is usually needed. In that case, __getitem__() handles file reading, preprocessing, and returning the label.
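To make this concrete, here is a minimal sketch of a file-backed map-style dataset. It assumes each sample was saved as a tensor file with torch.save and that labels are kept in a Python list; the class name FileTensorDataset, the file layout, and the standardization step are illustrative assumptions, not part of PyTorch. An image dataset would follow the same pattern, with an image decoder in place of torch.load.

```python
import tempfile
from pathlib import Path

import torch
import torch.utils.data as utils
from torch import Tensor


class FileTensorDataset(utils.Dataset):
    """Hypothetical map-style dataset: one saved tensor file per sample."""

    def __init__(self, paths: list[Path], labels: list[int]):
        if len(paths) != len(labels):
            raise ValueError('paths and labels must have the same length.')
        # Only lightweight metadata is stored; the data itself stays on disk.
        self.paths = paths
        self.labels = labels

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        # The file is read lazily, only when this index is requested.
        x = torch.load(self.paths[index], weights_only=True)
        # Example preprocessing (an assumption): standardize each sample.
        x = (x - x.mean()) / (x.std() + 1e-8)
        y = torch.tensor(self.labels[index], dtype=torch.long)
        return x, y


with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(4):
        path = Path(tmp) / f'sample_{i}.pt'
        torch.save(torch.randn(3, 8, 8), path)
        paths.append(path)

    file_dataset = FileTensorDataset(paths, labels=[0, 1, 0, 1])
    x0, y0 = file_dataset[0]
    print('Dataset length:', len(file_dataset))
    print('First input shape:', x0.shape, 'label:', y0.item())
```

Keeping only paths in __init__() and doing the reading in __getitem__() means the dataset stays small in memory, and the expensive work happens per sample, which is exactly the work DataLoader workers can later parallelize.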
At this point, the problem of accessing one sample is solved. During training, however, we rarely feed only one sample to the model. Samples are usually organized into mini-batches, and that part is handled by DataLoader.
2.4.2 DataLoader: Organize Samples into Mini-Batches
Dataset only knows how to retrieve one sample. DataLoader goes one step further and organizes samples into the mini-batches required by the training loop.
The simplest usage is:
```python
dataloader = utils.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

X, y = next(iter(dataloader))
print('Input batch shape:', X.shape)
print('Target batch shape:', y.shape)
```

Here DataLoader does four things:
- Retrieve several samples from the Dataset, with the number controlled by batch_size;
- Combine these samples into a batch;
- If shuffle=True, shuffle the sample order at each epoch;
- Return tensors that can be used directly for model training.
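The sketch below checks these behaviors on 1000 samples with batch_size=32, matching the opening example, and also answers the earlier question about the last batch. Each sample's feature is simply its own index so that samples can be tracked. drop_last is an additional DataLoader parameter not covered elsewhere in this section; it is used here only to show the alternative to a short final batch.

```python
import torch
import torch.utils.data as utils

# 1000 samples whose single "feature" is the sample index, so we can track them.
index_dataset = utils.TensorDataset(torch.arange(1000))

loader = utils.DataLoader(index_dataset, batch_size=32, shuffle=True)
# Each batch is a one-element list because every sample is a one-element tuple.
batch_sizes = [len(batch) for (batch,) in loader]
print('Number of batches:', len(batch_sizes))
print('Last batch size:', batch_sizes[-1])

# Every epoch visits each sample exactly once, in a freshly shuffled order.
epoch_1 = torch.cat([batch for (batch,) in loader])
epoch_2 = torch.cat([batch for (batch,) in loader])
print('Same samples in both epochs:',
      torch.equal(epoch_1.sort().values, epoch_2.sort().values))
print('Same order in both epochs:', torch.equal(epoch_1, epoch_2))

# drop_last=True discards the incomplete final batch instead of returning it.
loader_drop = utils.DataLoader(index_dataset, batch_size=32, shuffle=True, drop_last=True)
print('Batches with drop_last=True:', len(loader_drop))
```

Since 1000 = 31 × 32 + 8, the default behavior yields 31 full batches plus one batch of 8; with drop_last=True, only the 31 full batches remain.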
The training loop can therefore be written in the standard form:
```python
# Assumes model, loss_fn, and optimizer have already been defined.
for X, y in dataloader:
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

From this angle, DataLoader is the adapter between the training loop and the dataset. The model does not need to know how the raw data is stored, and the training loop does not need to manage indices, shuffling, or batch construction by itself.
Their relationship can be summarized as:
Dataset: index -> sample
DataLoader: samples -> batch
This is why PyTorch separates Dataset and DataLoader. The former describes the data itself, while the latter describes how data is sent into the training process.
2.4.3 collate_fn: How Samples Are Combined into a Batch
DataLoader combines multiple samples into one batch. The combining process is controlled by collate_fn.
If no collate_fn is given, DataLoader uses PyTorch's default collation logic (torch.utils.data.default_collate). For example, if each sample is (x, y), where x has shape (10,) and y has shape (1,), then 32 samples are automatically stacked into:
x: (32, 10)
y: (32, 1)
This is the behavior shown in the previous example.
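The default rule can also be called directly on a hand-built list of samples, which is a quick way to see what it does. The sample shapes below are chosen to match the (10,) and (1,) example; note that for tuple samples, default_collate returns a list of batched fields, and that the same stacking rule applies recursively to dictionaries.

```python
import torch
from torch.utils.data import default_collate

# 32 samples, each an (x, y) pair with x of shape (10,) and y of shape (1,).
samples = [(torch.randn(10), torch.randn(1)) for _ in range(32)]
X_batch, y_batch = default_collate(samples)
print('x:', tuple(X_batch.shape))
print('y:', tuple(y_batch.shape))

# Dictionaries of tensors are collated key by key.
dict_samples = [{'x': torch.randn(10), 'y': torch.randn(1)} for _ in range(32)]
dict_batch = default_collate(dict_samples)
print({key: tuple(value.shape) for key, value in dict_batch.items()})
```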
If samples have different shapes, the default collation fails. Variable-length sequences in natural language processing are a typical example. Suppose there are 4 samples, each with a different length:
```python
class VariableLengthDataset(utils.Dataset):
    def __init__(self):
        # Four token sequences of different lengths
        self.samples = [
            torch.tensor([1, 2, 3]),
            torch.tensor([4, 5]),
            torch.tensor([6, 7, 8, 9]),
            torch.tensor([10]),
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tensor:
        return self.samples[index]
```

If we use DataLoader directly, it tries to stack tensors of different lengths into a regular tensor. That is impossible:
```python
dataset = VariableLengthDataset()
dataloader = utils.DataLoader(dataset, batch_size=2)

try:
    batch = next(iter(dataloader))
except RuntimeError as err:
    print('RuntimeError:', err)
```

This is where a custom collate_fn is needed.
collate_fn receives a list containing the samples in the current batch. Its job is to arrange these samples into a batch that the model can consume. For example, variable-length sequences can be padded to the maximum length within the current batch:
```python
def pad_collate_fn(batch: list[Tensor]) -> tuple[Tensor, Tensor]:
    # Record each sequence's true length before padding
    lengths = torch.tensor([len(x) for x in batch])
    max_len = lengths.max().item()
    # Right-pad every sequence with zeros up to the longest one in this batch
    padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, x in enumerate(batch):
        padded[i, : len(x)] = x
    return padded, lengths
```

Then pass it to DataLoader:
```python
dataloader = utils.DataLoader(
    dataset,
    batch_size=2,
    collate_fn=pad_collate_fn,
)

for tokens, lengths in dataloader:
    print(f'tokens:\n{tokens}')
    print(f'lengths: {lengths}\n')
```

PyTorch will then use pad_collate_fn to organize each batch into a padded tensor and a length vector.
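The manual padding loop can also be written with torch.nn.utils.rnn.pad_sequence, which performs the same right-padding to the longest sequence in the batch. The sketch below additionally returns a boolean padding mask; the mask and the function name pad_sequence_collate_fn are illustrative assumptions about what a downstream model might want (for example, to ignore padded positions in attention or in the loss).

```python
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


def pad_sequence_collate_fn(batch: list[Tensor]) -> tuple[Tensor, Tensor, Tensor]:
    lengths = torch.tensor([len(x) for x in batch])
    # Pad every sequence with 0 up to the longest one in this batch.
    tokens = pad_sequence(batch, batch_first=True, padding_value=0)
    # True where a position holds a real token, False where it is padding.
    mask = torch.arange(tokens.size(1)) < lengths.unsqueeze(1)
    return tokens, lengths, mask


batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
tokens, lengths, mask = pad_sequence_collate_fn(batch)
print(tokens)
print(lengths)
print(mask)
```

It is passed to DataLoader exactly like pad_collate_fn, via collate_fn=pad_sequence_collate_fn.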
In real tasks, collate_fn is common. For example:
- pad variable-length sentences in text tasks;
- handle object detection batches where each image has a different number of bounding boxes;
- organize images, text, masks, and metadata into a dictionary in multimodal tasks;
- apply additional processing across the samples of a batch.
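The object detection case from the list above can be sketched as follows. ToyDetectionDataset and detection_collate_fn are hypothetical names, and the random images and box coordinates are placeholders. Images share a shape and can be stacked, but each image has a different number of boxes, so the collate function keeps the targets as a plain list and returns everything in a dictionary.

```python
import torch
import torch.utils.data as utils
from torch import Tensor


class ToyDetectionDataset(utils.Dataset):
    """Hypothetical detection dataset: each image has a different number of boxes."""

    def __init__(self, num_boxes: list[int]):
        self.num_boxes = num_boxes

    def __len__(self) -> int:
        return len(self.num_boxes)

    def __getitem__(self, index: int) -> tuple[Tensor, dict[str, Tensor]]:
        n = self.num_boxes[index]
        image = torch.rand(3, 64, 64)
        target = {
            'boxes': torch.rand(n, 4),  # placeholder coordinates, 4 values per box
            'labels': torch.randint(0, 5, (n,)),
        }
        return image, target


def detection_collate_fn(batch: list[tuple[Tensor, dict[str, Tensor]]]) -> dict:
    images, targets = zip(*batch)
    # Same-shaped images are stacked; variable-sized targets stay a list.
    return {'images': torch.stack(images), 'targets': list(targets)}


loader = utils.DataLoader(
    ToyDetectionDataset([2, 0, 5, 1]), batch_size=2, collate_fn=detection_collate_fn
)
first = next(iter(loader))
print('images:', tuple(first['images'].shape))
print('boxes per image:', [t['boxes'].shape[0] for t in first['targets']])
```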
Therefore, collate_fn can be understood as DataLoader’s packing rule. The default rule works for tensors with consistent shapes. When sample structures are more complex, we need to explicitly tell PyTorch how the batch should be assembled.
2.4.4 num_workers: Load Data and Compute the Model in Parallel
So far, all our DataLoaders load data in the main process. In other words, model training and data reading alternate in the same Python process:
read batch -> train one step -> read batch -> train one step -> ...
If each sample is only retrieved from in-memory tensors, this is usually not a problem. But if each sample requires reading images, decoding, data augmentation, or tokenization, CPU preprocessing may become the bottleneck. After the GPU finishes one batch, it has to wait for the CPU to prepare the next batch, and device utilization drops.
num_workers lets DataLoader start multiple subprocesses to load data ahead of time:
```python
dataloader = utils.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
```

When num_workers=0, all data loading happens in the main process. This is the simplest setting, the easiest to debug, and the least likely to produce unexpected multiprocessing issues. When num_workers>0, PyTorch starts multiple worker processes. They retrieve samples from the dataset ahead of time, run collate_fn, and put prepared batches into a queue. The main process can then train by taking ready batches from that queue:
worker 0 -> prepare batch
worker 1 -> prepare batch
worker 2 -> prepare batch
main process -> train model
This configuration allows data loading and model computation to overlap. Ideally, while the GPU is training on the current batch, CPU workers are already preparing later batches.
num_workers is not better simply because it is larger. It has several side effects.
First, more workers mean more processes and more memory overhead. Each worker holds its own copy of the dataset object, including the Python objects stored in it. If the dataset stores a very large Python list in __init__(), enabling multiple workers can noticeably increase memory usage.
Second, multiprocessing makes debugging harder. Errors inside Dataset.__getitem__() may not be displayed clearly in the main process. Sometimes the only visible message is similar to “DataLoader worker exited unexpectedly”. In that case, set num_workers back to 0 first so the error is exposed in the main process.
Finally, different operating systems start workers differently, which affects both code structure and performance.
On Linux, multiprocessing commonly uses fork. The child process is copied from the parent process with an almost identical state. Startup is fast, and memory pages can initially be shared through copy-on-write. However, Python updates an object's reference count even when the object is only read, which writes to the memory page holding it. If the dataset stores many Python objects, copy-on-write can therefore gradually duplicate those pages in every worker, and memory usage grows over time.
On Windows, there is no Unix-style fork, so multiprocessing uses spawn (macOS has also defaulted to spawn since Python 3.8). A child process starts a new Python interpreter, re-imports the script, and receives the objects it needs through serialization (pickling). This is safer, but startup is slower, and the dataset, collate_fn, and related objects must be serializable.
Therefore, when using num_workers>0 on Windows, the training entry point usually needs to be placed under:
```python
if __name__ == '__main__':
    main()
```

Otherwise, when the child process re-imports the script, it may repeatedly execute top-level code, recursively create processes, hang, or raise errors.
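A minimal script layout that follows this rule might look like the sketch below. The function names build_dataloader and main, and the random tensor data, are illustrative assumptions; the point is that creating and iterating the DataLoader happens inside main(), which only runs when the file is executed directly, not when a spawned worker re-imports it.

```python
import torch
import torch.utils.data as utils


def build_dataloader(num_workers: int) -> utils.DataLoader:
    dataset = utils.TensorDataset(torch.randn(100, 10), torch.randn(100, 1))
    return utils.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=num_workers)


def main(num_workers: int = 2) -> int:
    # Everything that creates worker processes runs inside main(), not at import time.
    dataloader = build_dataloader(num_workers)
    seen = 0
    for X, _ in dataloader:
        seen += X.shape[0]
    return seen


if __name__ == '__main__':
    # Under spawn, each worker re-imports this file; the guard prevents it from
    # calling main() again and creating workers recursively.
    print('Samples seen:', main())
```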
A cautious practice is to first use num_workers=0 to confirm that the code is correct; then gradually try num_workers=2, 4, 8; observe GPU utilization, CPU usage, and memory usage; and avoid increasing num_workers blindly.
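One way to run that comparison is to time the DataLoader on its own. The sketch below uses a hypothetical SlowDataset whose time.sleep call stands in for decoding or augmentation work; the delay, dataset size, and helper names are assumptions. It measures loader throughput in isolation, so it complements, rather than replaces, watching GPU utilization in the real training job.

```python
import time

import torch
import torch.utils.data as utils
from torch import Tensor


class SlowDataset(utils.Dataset):
    """Hypothetical dataset that simulates expensive per-sample preprocessing."""

    def __init__(self, size: int, delay_s: float = 0.001):
        self.size = size
        self.delay_s = delay_s

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> Tensor:
        time.sleep(self.delay_s)  # stand-in for decoding / augmentation work
        return torch.full((10,), float(index))


def samples_per_second(dataloader: utils.DataLoader) -> float:
    start = time.perf_counter()
    count = 0
    for batch in dataloader:
        count += batch.shape[0]
    return count / (time.perf_counter() - start)


def compare(worker_counts: list[int], size: int = 512) -> dict[int, float]:
    results = {}
    for num_workers in worker_counts:
        loader = utils.DataLoader(SlowDataset(size), batch_size=32, num_workers=num_workers)
        results[num_workers] = samples_per_second(loader)
    return results


if __name__ == '__main__':
    for n, rate in compare([0, 2, 4]).items():
        print(f'num_workers={n}: {rate:.0f} samples/s')
```

Note that the measurement includes worker startup, so very short runs understate the benefit of extra workers; a longer dataset or several epochs gives a fairer picture.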
The important point is that num_workers trades more CPU processes for a faster data supply. If the bottleneck is data reading and preprocessing, it is useful. If the data is already in memory and preprocessing is light, many workers may make the code slower instead.
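Jupyter notebooks add their own problems. Under spawn (Windows and macOS), worker processes must re-import the classes and functions they use, but anything defined in a notebook cell lives in the notebook's __main__ module and usually cannot be found by the workers, so loading fails. If you are debugging code in a notebook, keep num_workers=0, or move the Dataset and collate_fn definitions into an importable .py file.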
2.4.5 persistent_workers: Avoid Restarting Workers Every Epoch
When num_workers>0, DataLoader starts multiple worker processes. By default, after one epoch finishes, these worker processes are shut down. At the start of the next epoch, they are created again.
If dataset initialization is light, this may not be noticeable. But if worker startup is expensive, such as re-importing many modules, initializing data reading state, opening file handles, or preparing caches, restarting workers at every epoch wastes time.
persistent_workers=True keeps workers alive after an epoch ends and reuses them in the next epoch:
```python
dataloader = utils.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    persistent_workers=True,
)
```

persistent_workers=True requires num_workers>0. With num_workers=0 there are no subprocesses to keep alive, and DataLoader raises a ValueError if persistent_workers=True is passed.
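The constraint is checked when the DataLoader is constructed, so the mistake surfaces immediately rather than at the start of training. A small check, using placeholder tensor data:

```python
import torch
import torch.utils.data as utils

dataset = utils.TensorDataset(torch.randn(100, 10))

raised = False
try:
    # persistent_workers=True without any worker processes is rejected.
    utils.DataLoader(dataset, batch_size=32, num_workers=0, persistent_workers=True)
except ValueError as err:
    raised = True
    print('ValueError:', err)
```

In scripts where num_workers is configurable, writing persistent_workers=num_workers > 0 avoids this error.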
The benefit is reduced worker startup cost between epochs. This can be especially helpful on Windows, where spawn makes worker startup slower. The cost is that workers remain alive, so their memory and resources are also kept until the DataLoader is destroyed. If the dataset maintains internal state, note that this state is preserved across epochs instead of being reinitialized every epoch.
In short:
- num_workers controls whether subprocesses are used;
- persistent_workers controls whether those subprocesses remain alive across epochs.
If training has many epochs and num_workers>0, it is worth trying persistent_workers=True. For debugging, small-data experiments, or memory-constrained situations, the default value is also perfectly fine.
2.4.6 pin_memory: Should Batches Be Placed in Page-Locked Memory?
For more details about pin_memory and non_blocking, as well as their measured effects in benchmarks, see PyTorch’s official document A guide on good usage of non_blocking and pin_memory in PyTorch. It analyzes CPU-to-GPU data transfer mechanisms and the performance impact of pin_memory and non_blocking in different scenarios.
When training on a GPU, each batch is usually prepared by the CPU first and then copied to the GPU:
```python
X = X.to(device)
y = y.to(device)
```

This performs a host-to-device copy, meaning a copy from CPU memory to GPU memory. pin_memory=True tells DataLoader to place returned tensors in pinned memory, also called page-locked memory.
Ordinary CPU memory is pageable: the operating system may move it out to disk. Pinned memory cannot be paged out, so the GPU can read it directly through DMA, whereas copies from pageable memory are first staged through a pinned buffer by the CUDA driver. If training frequently moves batches from CPU to GPU, pinned memory can reduce part of the overhead on this data transfer path.
Usage is simple:
```python
dataloader = utils.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    pin_memory=device.type == 'cuda',
)
```

Note that pin_memory is usually aimed at CUDA. Therefore, when training with CUDA, it is common to set:

```python
pin_memory=True
```

Then use non_blocking=True when moving data to the GPU:

```python
X = X.to('cuda', non_blocking=True)
y = y.to('cuda', non_blocking=True)
```

Here, non_blocking=True means that, if conditions allow, the CPU-to-GPU copy can run asynchronously and overlap with part of the computation. In practice, the source tensor usually needs to be in pinned memory for this to take effect.
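The two settings are normally used together. The sketch below assumes CUDA as the target accelerator and falls back to the CPU (with pinning and non_blocking disabled) otherwise, so the same code runs on CPU-only machines; the random data and the running sum, which only gives the loop some device work, are placeholders.

```python
import torch
import torch.utils.data as utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_cuda = device.type == 'cuda'

dataset = utils.TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))
# Only request pinned memory when batches will actually be copied to a CUDA GPU.
dataloader = utils.DataLoader(dataset, batch_size=64, pin_memory=use_cuda)

total = torch.zeros((), device=device)
for X, y in dataloader:
    # With a pinned source tensor, these copies can overlap with other GPU work.
    X = X.to(device, non_blocking=use_cuda)
    y = y.to(device, non_blocking=use_cuda)
    total += X.sum() + y.sum()

# .cpu() on a CUDA tensor waits for the queued work before returning the value.
print('Sum of all batches:', total.cpu().item())
```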
Should pin_memory always be enabled? Not necessarily.
If you train with CUDA and batches are copied from CPU to GPU, pin_memory=True is usually worth trying, and many training scripts enable it by default. But if you train only on CPU, or the data is already on the GPU, or you use a non-CUDA accelerator such as Intel XPU, or batches are small and transfer is not the bottleneck, the benefit may be small and there may even be slight extra overhead.
A more practical rule is:
- CUDA training: try pin_memory=True first;
- CPU training: usually unnecessary;
- uncertain whether it helps: measure actual training throughput.
pin_memory is not a correctness switch for the model. It is an efficiency switch for data transfer. The question it addresses is not whether training can run, but whether CPU-prepared batches can reach the GPU faster.
2.4.7 IterableDataset: When Data Cannot Be Randomly Accessed
So far, we have assumed the most common map-style dataset. The earlier SimpleTensorDataset is an example. Its key property is random access by index:
```python
sample = dataset[index]
```

Such datasets usually have a clear length and can be shuffled naturally. Image classification datasets, prepared text classification datasets, and ordinary tensor datasets are often written as map-style datasets.
Not all data is suitable for indexed access. Some data comes from streams that continuously produce samples:
- continuously reading new data from a logging system;
- streaming samples from a remote service or database;
- sequentially scanning a very large file;
- dynamically generating training samples in generative tasks.
In these cases, the total length may be unknown, and the i-th sample may not be randomly accessible. The only available operation is to keep iterating forward:
sample_1 -> sample_2 -> sample_3 -> ...
This kind of data is better represented by IterableDataset.
IterableDataset does not implement __getitem__(). Instead, it implements __iter__():
```python
class SimpleCountingDataset(utils.IterableDataset):
    def __init__(self, end: int):
        self.end = end

    def __iter__(self) -> Iterator[Tensor]:
        # Produce samples one at a time as a stream
        for i in range(self.end):
            yield torch.tensor(i)
```

We can read it directly with DataLoader:
```python
dataset = SimpleCountingDataset(end=10)
dataloader = utils.DataLoader(dataset, batch_size=4)

for i, batch in enumerate(dataloader, start=1):
    print(f'Batch {i}: {batch}')
```

The semantics have changed here. For a map-style dataset, DataLoader can generate a set of indices through a sampler and use those indices to retrieve samples. For an iterable-style dataset, DataLoader can only keep consuming the stream produced by __iter__().
Therefore, the core difference between a map-style Dataset and an IterableDataset is not the code form, but the data access pattern:
| Type | Core interface | Data access pattern | Typical scenarios |
|---|---|---|---|
| Dataset | __len__() and __getitem__() | Random access by index | Tensor data, image classification, fixed text datasets |
| IterableDataset | __iter__() | Sequential iteration over a data stream | Streaming data, very large files, online sample generation |
There are also some usage differences.
First, shuffle=True only applies to map-style datasets, because shuffling there just means permuting indices. For an IterableDataset, the data is generated as a stream, so DataLoader cannot know all samples in advance and shuffle them globally; passing shuffle=True together with an IterableDataset raises a ValueError. If streaming data needs randomization, buffer shuffling or data sharding is usually implemented inside the IterableDataset itself.
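Buffer shuffling keeps a fixed-size window of the stream in memory and emits a random element from that window each time a new element arrives. The sketch below is a minimal, hypothetical implementation (BufferShuffleDataset is not a PyTorch class). Randomness is only local, within roughly buffer_size positions, so a larger buffer gives better mixing at the cost of memory. Combined with multiple workers, it would also need the sharding described next.

```python
import random
from collections.abc import Iterable, Iterator
from typing import Optional

import torch.utils.data as utils


class BufferShuffleDataset(utils.IterableDataset):
    """Hypothetical wrapper: approximate shuffling of a stream with a fixed-size buffer."""

    def __init__(self, source: Iterable, buffer_size: int, seed: Optional[int] = None):
        self.source = source
        self.buffer_size = buffer_size
        self.seed = seed  # None draws a fresh order each epoch

    def __iter__(self) -> Iterator:
        rng = random.Random(self.seed)
        buffer = []
        for item in self.source:
            buffer.append(item)
            if len(buffer) >= self.buffer_size:
                # Swap a random element to the end and emit it; the next item refills the slot.
                index = rng.randrange(len(buffer))
                buffer[index], buffer[-1] = buffer[-1], buffer[index]
                yield buffer.pop()
        # The stream is exhausted: emit the remaining items in random order.
        rng.shuffle(buffer)
        yield from buffer


shuffled = BufferShuffleDataset(range(20), buffer_size=8, seed=0)
loader = utils.DataLoader(shuffled, batch_size=5)
batches = [batch.tolist() for batch in loader]
print(batches)
```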
Second, when IterableDataset is used with multiple workers, each worker receives a copy of the dataset. Without manual partitioning, different workers may read duplicate samples. For this case, we can use get_worker_info() inside __iter__() to identify the current worker and assign different ranges to different workers:
```python
class ShardedCountingDataset(utils.IterableDataset):
    def __init__(self, end: int):
        self.end = end

    def __iter__(self) -> Iterator[Tensor]:
        worker_info = utils.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator produces every sample
            start = 0
            step = 1
        else:
            # Inside a worker: take every num_workers-th sample, offset by worker id
            start = worker_info.id
            step = worker_info.num_workers
        for i in range(start, self.end, step):
            yield torch.tensor(i)
```

In this example, with 2 workers, worker 0 produces 0, 2, 4, ..., and worker 1 produces 1, 3, 5, ..., avoiding duplicate reads of the same data.
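Striding is only one way to split the work. The sketch below instead gives each worker a contiguous chunk, which tends to suit sources read sequentially, such as a large file where each worker can start at its own offset. The helper contiguous_shard and the class name are hypothetical; the chunk arithmetic is factored into a plain function so it can be checked without starting any workers.

```python
import math
from collections.abc import Iterator

import torch
import torch.utils.data as utils
from torch import Tensor


def contiguous_shard(start: int, end: int, worker_id: int, num_workers: int) -> range:
    """Split [start, end) into num_workers contiguous chunks and return one of them."""
    per_worker = math.ceil((end - start) / num_workers)
    shard_start = start + worker_id * per_worker
    shard_end = min(shard_start + per_worker, end)
    return range(shard_start, shard_end)  # empty if this worker has nothing left


class ContiguousCountingDataset(utils.IterableDataset):
    def __init__(self, end: int):
        self.end = end

    def __iter__(self) -> Iterator[Tensor]:
        worker_info = utils.get_worker_info()
        if worker_info is None:
            indices = range(0, self.end)  # single-process loading: read everything
        else:
            indices = contiguous_shard(0, self.end, worker_info.id, worker_info.num_workers)
        for i in indices:
            yield torch.tensor(i)


shards = [list(contiguous_shard(0, 10, w, 3)) for w in range(3)]
print(shards)
```

With 10 samples and 3 workers, each chunk holds ceil(10 / 3) = 4 samples, so the last worker receives the remaining 2.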
In small experiments, map-style datasets are still used most of the time. The value of IterableDataset is that it shows PyTorch’s data pipeline is not only for small prepared datasets, but can also handle data streams closer to real systems.
For more engineering extensions around PyTorch data loading, see torchdata. One direction worth watching is StatefulDataLoader: it allows the data loader to save state like models and optimizers. When training is interrupted in the middle of an epoch, it may not have to restart from the beginning of the epoch, but can try to resume closer to the interrupted position.
2.4.8 A DataLoader Configuration Closer to a Training Script
At this point, the commonly confusing DataLoader parameters are mostly clear:
- batch_size controls how many samples are in each mini-batch;
- shuffle controls whether sample order is shuffled at each epoch;
- collate_fn controls how multiple samples are combined into a batch;
- num_workers controls whether multiple processes prefetch data;
- persistent_workers controls whether workers remain alive across epochs;
- pin_memory controls whether batches are placed in pinned memory to accelerate CPU-to-GPU copies.
In a common GPU training script, the DataLoader configuration can look like this:
This example uses the newer torch.accelerator API to dynamically detect the currently available accelerator device. If CUDA is available, it uses CUDA; if XPU is available, it uses XPU; if no accelerator is available, it falls back to CPU. Compared with directly writing torch.device('cuda' if torch.cuda.is_available() else 'cpu'), this method better covers different accelerator types. Later training code will continue to use this approach for device selection. This API was introduced in PyTorch 2.6.
```python
device = torch.accelerator.current_accelerator(check_available=True)
if device is None:
    device = torch.device('cpu')

X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
dataset = utils.TensorDataset(X, y)
dataloader = utils.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=device.type == 'cuda',
    persistent_workers=True,
)
```

In the training loop:
```python
# Assumes model (already moved to device), loss_fn, and optimizer are defined.
for X, y in dataloader:
    X = X.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

This configuration is not optimal in every situation. On Windows, for example, num_workers=2 starts subprocesses with spawn, and the first startup may be slow. In notebook environments, multiprocessing is also more likely to run into serialization or entry-point problems. If the data is only small in-memory tensors, num_workers=0 may be simpler and faster.
Therefore, DataLoader parameters should not be memorized as a fixed template. They should be adjusted according to the bottleneck:
- Code is not running correctly yet: start with num_workers=0;
- Data reading is slow and the GPU waits for data: increase num_workers;
- CUDA training and the CPU-to-GPU copy is a visible cost: try pin_memory=True;
- Many epochs and slow worker startup: try persistent_workers=True;
- Samples cannot be collated by default: write a collate_fn.
Seen this way, DataLoader is not just a place to fill in batch_size. It is part of what determines training throughput.
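Putting the pieces of this section together, the sketch below is one runnable end-to-end version of the configuration above. The synthetic linear-regression data, the single nn.Linear model, SGD, and the helper names pick_device and train are placeholders for a real task. Device selection follows the torch.accelerator approach but checks for the module first, in case of an older PyTorch release, and persistent_workers is tied to num_workers > 0 because DataLoader rejects it otherwise.

```python
import torch
import torch.utils.data as utils


def pick_device() -> torch.device:
    # torch.accelerator only exists in newer PyTorch releases; otherwise use the CPU.
    if hasattr(torch, 'accelerator') and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device('cpu')


def train(num_workers: int = 0, epochs: int = 5) -> list[float]:
    torch.manual_seed(0)
    device = pick_device()

    # Synthetic data with a learnable relationship: y = X @ w + noise.
    X = torch.randn(1000, 10)
    w = torch.randn(10, 1)
    y = X @ w + 0.1 * torch.randn(1000, 1)

    dataloader = utils.DataLoader(
        utils.TensorDataset(X, y),
        batch_size=64,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=device.type == 'cuda',
        persistent_workers=num_workers > 0,
    )

    model = torch.nn.Linear(10, 1).to(device)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

    epoch_losses = []
    for _ in range(epochs):
        total, count = 0.0, 0
        for X_batch, y_batch in dataloader:
            X_batch = X_batch.to(device, non_blocking=True)
            y_batch = y_batch.to(device, non_blocking=True)
            loss = loss_fn(model(X_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a sample-weighted average loss for this epoch.
            total += loss.item() * X_batch.shape[0]
            count += X_batch.shape[0]
        epoch_losses.append(total / count)
    return epoch_losses


if __name__ == '__main__':
    for epoch, loss in enumerate(train(), start=1):
        print(f'epoch {epoch}: mean loss {loss:.4f}')
```

Following the checklist above, the script starts with num_workers=0; once it runs correctly, train(num_workers=2) or higher can be tried inside the same main guard.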
2.4.9 Summary
This section started from the data problem in the training loop and introduced PyTorch’s Dataset and DataLoader.
Dataset defines how a single sample is retrieved. The common map-style dataset supports indexed access through __len__() and __getitem__(). IterableDataset produces a data stream through __iter__() and is better suited to streaming data and sources that cannot be randomly accessed.
DataLoader turns samples into mini-batches. By default, it automatically stacks tensors with consistent shapes. If samples have different lengths or more complex structures, collate_fn is needed to customize the packing process.
On the efficiency side, num_workers can use multiple subprocesses to load data ahead of time, but it also introduces memory, debugging, and cross-platform issues. Windows uses spawn, so the if __name__ == '__main__' guard and object serializability matter. Linux commonly uses fork, which starts faster, but complex state and memory usage still require care. pin_memory mainly serves CPU-to-GPU data copies in CUDA training, while persistent_workers reduces repeated worker startup cost across epochs.
The core of DataLoader is not a fixed configuration. It is to make data supply keep up with model training. Small experiments can start with simple settings; when data reading becomes the bottleneck, multiprocessing, pinned memory, and persistent workers can be enabled gradually.
At this point, the basic structure and common configuration of the PyTorch data pipeline are complete. The next section turns to model definition in PyTorch: nn.Module and nn.functional.