When working with massive datasets in machine learning, managing storage and processing efficiently is crucial. In this guide, we will explore how to use Hugging Face’s streaming datasets in PyTorch, which let you load only the data you need instead of downloading everything up front.
This guide uses `collabora/librilight-webdataset` (a 4TB dataset) to demonstrate streaming mode in Hugging Face’s `datasets` library and to integrate it with PyTorch.
To load a streaming dataset, pass `streaming=True` to `load_dataset`. Instead of downloading the entire dataset first, this returns an iterable dataset that fetches examples on demand as you iterate over it.
```python
from datasets import load_dataset

# Load the training split in streaming mode (no full download)
dataset = load_dataset('collabora/librilight-webdataset', split='train', streaming=True)
```
Note: A streaming dataset does not report a length, because examples are read lazily from remote files rather than indexed in advance. Operations that rely on `len()` or on random access by index are therefore not directly available.
Next, we wrap the stream in a map-style PyTorch `Dataset`. Since we know the dataset length in advance, `__len__` can simply return it. The `__getitem__` method walks the stream from the beginning until it reaches the requested index, so items are never held in memory, at the cost of re-reading the stream on every call.
```python
import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class MyAudioDataset(Dataset):
    def __init__(self, dataset_name, split, dataset_length, max_length=10.0):
        self.dataset = load_dataset(dataset_name, split=split, streaming=True)
        self.dataset_length = dataset_length
        self.max_length = max_length  # max length in seconds

    def __len__(self):
        return self.dataset_length  # Supplied by the caller; the stream cannot report it

    def __getitem__(self, idx):
        if idx >= self.dataset_length:
            raise IndexError("Index out of range")
        # Streams have no random access: iterate from the start until we reach idx
        for i, item in enumerate(self.dataset):
            if i == idx:
                audio = item['flac']['array']
                sample_rate = item['flac']['sampling_rate']
                # Limit audio to max_length seconds with a random crop
                max_samples = int(self.max_length * sample_rate)
                if len(audio) > max_samples:
                    # +1 because torch.randint's upper bound is exclusive
                    start_idx = torch.randint(0, len(audio) - max_samples + 1, (1,)).item()
                    audio = audio[start_idx:start_idx + max_samples]
                return {
                    "language": item['json']['book_meta']['language'],
                    "sample_rate": sample_rate,
                    "audio": audio
                }
        raise IndexError("Index out of range")
```
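Warning: Iterable (streaming) datasets support only sequential iteration, not indexing. Because `__getitem__` re-reads the stream from the start on every call, fetching item `idx` takes time proportional to `idx`, and a full epoch takes time roughly quadratic in the dataset size. This wrapper is fine for quick experiments but does not scale to the full dataset.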
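A more natural fit for a stream is PyTorch's `torch.utils.data.IterableDataset`, which yields examples in order instead of looking them up by index. The following is a minimal sketch under assumptions, not the original author's code: `StreamingAudioDataset` is a hypothetical name, and it accepts any re-iterable sequence of examples shaped like the LibriLight items above (with `flac` and `json` fields), applying the same random crop in a single pass.

```python
import numpy as np
from torch.utils.data import IterableDataset


class StreamingAudioDataset(IterableDataset):
    """Hypothetical wrapper that crops clips while reading a stream sequentially."""

    def __init__(self, examples, max_length=10.0, seed=None):
        self.examples = examples      # any re-iterable stream of LibriLight-style dicts
        self.max_length = max_length  # maximum clip length in seconds
        self.seed = seed              # None gives different random crops on each pass

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        for item in self.examples:  # one pass in stream order; no index lookups
            audio = item["flac"]["array"]
            sample_rate = item["flac"]["sampling_rate"]
            max_samples = int(self.max_length * sample_rate)
            if len(audio) > max_samples:
                # Random crop; the +1 makes the last valid start position reachable
                start = int(rng.integers(0, len(audio) - max_samples + 1))
                audio = audio[start:start + max_samples]
            yield {
                "language": item["json"]["book_meta"]["language"],
                "sample_rate": sample_rate,
                "audio": audio,
            }
```

With the real data, you would pass the stream itself, e.g. `StreamingAudioDataset(load_dataset('collabora/librilight-webdataset', split='train', streaming=True))`, and hand it to `DataLoader` without `shuffle=True`, since PyTorch does not allow sampler-based shuffling for iterable datasets. For approximate shuffling, the streaming dataset's own `shuffle(buffer_size=...)` method can be applied before wrapping. If you use `num_workers > 1`, check the `datasets` documentation on PyTorch integration for how a stream's shards are divided among workers.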
To load data in parallel, PyTorch’s `DataLoader` accepts a `num_workers` parameter, which starts that many worker processes to fetch and prepare examples concurrently. This helps hide network latency when reading from a remote stream.
```python
from torch.utils.data import DataLoader

dataset_length = 219041  # Known length
audio_dataset = MyAudioDataset('collabora/librilight-webdataset', 'train', dataset_length)

# DataLoader with parallel loading
dataloader = DataLoader(audio_dataset, batch_size=8, num_workers=4)

# Iterate over DataLoader
for batch in dataloader:
    print(batch)
    break  # Display only the first batch
```
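One caveat with this batching: PyTorch's default collate function stacks the `audio` arrays, which only works when every clip in a batch has the same number of samples. The cropping step caps clips at `max_length` but does not lengthen shorter ones, so a batch containing a short clip would fail to stack. A common fix is a custom `collate_fn` that zero-pads each batch to its longest clip. The sketch below (`pad_collate` is a hypothetical name) also returns the original lengths so a model can mask out the padding.

```python
import numpy as np
import torch


def pad_collate(batch):
    """Zero-pad a list of example dicts to the longest clip in the batch."""
    lengths = torch.tensor([len(ex["audio"]) for ex in batch], dtype=torch.long)
    audio = torch.zeros(len(batch), int(lengths.max()), dtype=torch.float32)
    for i, ex in enumerate(batch):
        clip = torch.as_tensor(np.asarray(ex["audio"]), dtype=torch.float32)
        audio[i, : len(clip)] = clip  # remaining positions stay zero
    return {
        "language": [ex["language"] for ex in batch],
        "sample_rate": torch.tensor([ex["sample_rate"] for ex in batch]),
        "audio": audio,      # shape (batch_size, longest_clip)
        "lengths": lengths,  # true number of samples per clip, for masking
    }
```

It plugs into the loader as `DataLoader(audio_dataset, batch_size=8, num_workers=4, collate_fn=pad_collate)`.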
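When working with audio, setting a `max_length` keeps sample lengths under control. In our example, the default is 10 seconds: any clip longer than that is replaced by a random 10-second crop, while shorter clips are returned unchanged. Bounding clip length keeps per-batch memory use predictable during training.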
```python
import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class MyAudioDataset(Dataset):
    def __init__(self, dataset_name, split, dataset_length, max_length=10.0):
        self.dataset = load_dataset(dataset_name, split=split, streaming=True)
        self.dataset_length = dataset_length
        self.max_length = max_length  # Default 10 seconds

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, idx):
        for i, item in enumerate(self.dataset):
            if i == idx:
                audio = item['flac']['array']
                sample_rate = item['flac']['sampling_rate']
                # Convert the time limit in seconds to a number of samples
                max_samples = int(self.max_length * sample_rate)
                if len(audio) > max_samples:
                    # Random crop of exactly max_samples; shorter clips pass through unchanged
                    start_idx = torch.randint(0, len(audio) - max_samples + 1, (1,)).item()
                    audio = audio[start_idx:start_idx + max_samples]
                return {
                    "language": item['json']['book_meta']['language'],
                    "sample_rate": sample_rate,
                    "audio": audio
                }
        raise IndexError("Index out of range")
```
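If every example must be exactly `max_length` seconds long, short clips can be zero-padded in addition to cropping long ones. The helper below is a hypothetical sketch in NumPy, not part of the original code: it converts seconds to samples (for example, 10 s at a 16,000 Hz sampling rate is 160,000 samples), randomly crops longer clips, and pads shorter ones with trailing zeros.

```python
import numpy as np


def crop_or_pad(audio, sample_rate, max_length=10.0, rng=None):
    """Return exactly int(max_length * sample_rate) samples via random crop or zero-padding."""
    rng = rng if rng is not None else np.random.default_rng()
    audio = np.asarray(audio)
    max_samples = int(max_length * sample_rate)  # e.g. 10.0 s * 16000 Hz = 160000
    if len(audio) > max_samples:
        # Random crop; +1 because the upper bound of rng.integers is exclusive
        start = int(rng.integers(0, len(audio) - max_samples + 1))
        return audio[start:start + max_samples]
    # Pad the tail with zeros (silence) up to max_samples
    return np.pad(audio, (0, max_samples - len(audio)))
```

Calling it on `item['flac']['array']` inside `__getitem__` would make every returned clip the same length, so the default collate function can stack them directly.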
Tip: You can customize the returned dictionary to include additional metadata fields depending on your requirements.