Streaming Hugging Face Datasets in PyTorch


When working with massive machine learning datasets, storage and processing efficiency are crucial. This guide shows how to use Hugging Face’s streaming datasets in PyTorch so that you load only the data you need instead of downloading everything first. It uses collabora/librilight-webdataset (a 4TB dataset) to demonstrate the streaming mode of Hugging Face’s datasets library and its integration with PyTorch.

🛠️ Setting Up Streaming in Hugging Face

To stream a dataset, pass streaming=True to load_dataset. Nothing is downloaded up front; examples are fetched as you iterate over the dataset.


from datasets import load_dataset

# Load the dataset in streaming mode
dataset = load_dataset('collabora/librilight-webdataset', split='train', streaming=True)
    

Note: A streaming dataset is an iterable with no fixed length: examples are fetched on the fly as you iterate. Operations that depend on a length or an index, such as len(dataset) or dataset[i], are not available on it directly.
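
Because a streamed dataset has no length, a common way to inspect it is to read just the first few examples. The following is a minimal sketch; `peek` is a hypothetical helper name, not part of the datasets library, and it works on any Python iterable.

```python
from itertools import islice


def peek(stream, n=3):
    """Return the first n examples of any iterable as a list.

    Only n items are pulled from the stream; nothing else is fetched.
    """
    return list(islice(stream, n))


# Assumed usage with the streaming dataset created above:
#   first_two = peek(dataset, 2)
#   print(first_two[0].keys())
```

Streaming datasets in the datasets library also provide a take(n) method that returns a new streaming dataset limited to the first n examples, which may be more convenient when you want to keep working with a dataset object.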

📂 Creating a Custom PyTorch Dataset

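We’ll create a PyTorch dataset class to wrap the streaming data. A streaming dataset has no length of its own, so we pass in a known (or approximate) example count as dataset_length and return it from __len__. The __getitem__ method restarts the stream and advances to the requested index, so nothing is held in memory, but reaching index i requires reading the i examples before it.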


import torch
from torch.utils.data import Dataset
from datasets import load_dataset

class MyAudioDataset(Dataset):
    def __init__(self, dataset_name, split, dataset_length, max_length=10.0):
        self.dataset = load_dataset(dataset_name, split=split, streaming=True)
        self.dataset_length = dataset_length
        self.max_length = max_length  # max length in seconds

    def __len__(self):
        return self.dataset_length  # Approximate length known

    def __getitem__(self, idx):
        if idx >= self.dataset_length:
            raise IndexError("Index out of range")

        # Restart the stream and advance to idx; cost grows linearly with idx
        iterator = iter(self.dataset)
        for i, item in enumerate(iterator):
            if i == idx:
                audio = item['flac']['array']
                sample_rate = item['flac']['sampling_rate']

                # Limit audio to max_length in seconds
                max_samples = int(self.max_length * sample_rate)
                if len(audio) > max_samples:
                    # randint's upper bound is exclusive; +1 makes the last window reachable
                    start_idx = torch.randint(0, len(audio) - max_samples + 1, (1,)).item()
                    audio = audio[start_idx:start_idx + max_samples]

                return {
                    "language": item['json']['book_meta']['language'],
                    "sample_rate": sample_rate,
                    "audio": audio
                }

        raise IndexError("Index out of range")
    

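Warning: Iterable datasets do not natively support length or index-based access. The wrapper above emulates indexing by re-reading the stream from the beginning on every call, so the cost of fetching an item grows linearly with its index.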
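
One way to avoid re-reading the stream for every index is to wrap it in a PyTorch IterableDataset instead of a map-style Dataset. The sketch below is an alternative to the class above, not a drop-in replacement: `StreamingAudioDataset` is a hypothetical name, the 'flac' and 'json' field names are taken from the example above, and the per-worker sharding with itertools.islice is just one simple choice. It also uses Python's random module for the crop offset rather than torch.randint.

```python
import itertools
import random

from torch.utils.data import IterableDataset, get_worker_info


class StreamingAudioDataset(IterableDataset):
    """Hypothetical wrapper that yields cropped examples from any iterable."""

    def __init__(self, stream, max_length=10.0):
        self.stream = stream          # e.g. load_dataset(..., streaming=True)
        self.max_length = max_length  # max clip length in seconds

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: read every example.
            examples = iter(self.stream)
        else:
            # Each worker keeps every num_workers-th example, starting at
            # its own id, so workers do not yield duplicates.
            examples = itertools.islice(self.stream, info.id, None, info.num_workers)

        for item in examples:
            audio = item['flac']['array']
            sample_rate = item['flac']['sampling_rate']
            max_samples = int(self.max_length * sample_rate)
            if len(audio) > max_samples:
                # random.randint is inclusive on both ends
                start = random.randint(0, len(audio) - max_samples)
                audio = audio[start:start + max_samples]
            yield {
                "language": item['json']['book_meta']['language'],
                "sample_rate": sample_rate,
                "audio": audio,
            }
```

With this design each example is read once per pass over the stream. In the islice-based sharding every worker still iterates over the full stream and discards the examples that belong to other workers, so it removes duplicates but does not reduce the amount of data each worker reads.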

⏩ Using DataLoader with Parallel Workers

PyTorch’s DataLoader can load data in parallel: setting num_workers starts that many worker processes, each of which retrieves data concurrently. This helps hide the latency of fetching examples from a large streaming dataset.


from torch.utils.data import DataLoader

dataset_length = 219041  # Known length
audio_dataset = MyAudioDataset('collabora/librilight-webdataset', 'train', dataset_length)

# DataLoader with parallel loading
dataloader = DataLoader(audio_dataset, batch_size=8, num_workers=4)

# Iterate over DataLoader
for batch in dataloader:
    print(batch)
    break  # Display only the first batch
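

The default batching stacks the audio arrays, which only works when every clip in a batch has the same number of samples. If some clips are shorter than max_length, a custom collate function can pad them. The sketch below is one possible approach; `pad_collate` and the extra "lengths" key are hypothetical names introduced for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Hypothetical collate_fn: zero-pads clips to the longest one in the batch."""
    clips = [torch.as_tensor(example["audio"], dtype=torch.float32) for example in batch]
    return {
        # Padded audio with shape (batch, longest_clip)
        "audio": pad_sequence(clips, batch_first=True),
        # Original, unpadded length of each clip in samples
        "lengths": torch.tensor([len(clip) for clip in clips]),
        "sample_rate": torch.tensor([example["sample_rate"] for example in batch]),
        "language": [example["language"] for example in batch],
    }
```

It would be passed to the loader as DataLoader(audio_dataset, batch_size=8, num_workers=4, collate_fn=pad_collate).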
    

🎙️ Bonus: Efficient Handling of Audio Data

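When working with audio, setting a max_length caps the duration of each sample. In our example the default is 10 seconds: longer clips are cropped to a randomly chosen 10-second window, while shorter clips are returned unchanged. Bounding clip length keeps batches more uniform, which is beneficial for model training. The relevant part of MyAudioDataset is repeated below.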


class MyAudioDataset(Dataset):
    def __init__(self, dataset_name, split, dataset_length, max_length=10.0):
        self.dataset = load_dataset(dataset_name, split=split, streaming=True)
        self.dataset_length = dataset_length
        self.max_length = max_length  # Default 10 seconds

    def __getitem__(self, idx):
        iterator = iter(self.dataset)
        for i, item in enumerate(iterator):
            if i == idx:
                audio = item['flac']['array']
                sample_rate = item['flac']['sampling_rate']

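                # Crop to a random window of at most max_length seconds
                max_samples = int(self.max_length * sample_rate)
                if len(audio) > max_samples:
                    # randint's upper bound is exclusive; +1 makes the last window reachable
                    start_idx = torch.randint(0, len(audio) - max_samples + 1, (1,)).item()
                    audio = audio[start_idx:start_idx + max_samples]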

                return {
                    "language": item['json']['book_meta']['language'],
                    "sample_rate": sample_rate,
                    "audio": audio
                }
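

The cropping step can also be factored out into a small standalone function, which makes it easier to test in isolation. This is a sketch under the same assumptions as the class above; `random_crop` is a hypothetical helper name, and it uses Python's random module instead of torch.randint.

```python
import random


def random_crop(audio, sample_rate, max_length=10.0, rng=random):
    """Return at most max_length seconds of audio, cropped at a random offset."""
    max_samples = int(max_length * sample_rate)
    if len(audio) <= max_samples:
        return audio  # already short enough; returned unchanged
    # randint is inclusive on both ends, so the final window is reachable
    start = rng.randint(0, len(audio) - max_samples)
    return audio[start:start + max_samples]
```

Because it relies only on len() and slicing, the same function works on Python lists and NumPy arrays.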
    

Tip: You can customize the returned dictionary to include additional metadata fields depending on your requirements.