Source code for torchdyno.data.datasets.memory_capacity
import torch
from torch.utils.data import Dataset
[docs]
class MemoryCapacityDataset(Dataset):
"""Memory capacity dataset.
The memory capacity dataset is a simple dataset that is used to evaluate the memory
capacity of recurrent neural networks. The dataset is generated by sampling random
numbers from a uniform distribution and is used to predict the next number in the
sequence.
"""
TRAIN_SIZE = 5000
TEST_SIZE = 1000
[docs]
def __init__(
self,
delay: int,
length: int = 6000,
seed: int = 0,
return_full_sequence: bool = False,
):
"""Memory capacity dataset.
Args:
delay: The delay between the input and the target. Ideally, 2*hidden_size of
the evaluated RNN.
length: The length of the dataset.
seed: Random seed.
"""
self.length = length
self.delay = delay
self.seed = seed
self.return_full_sequence = return_full_sequence
self.data = self._generate_data()
def __len__(self):
"""Return the length of the dataset."""
if self.return_full_sequence:
return 1
return self.length - self.delay
def __getitem__(self, idx: int):
"""Return the item at the given index.
If return_full_sequence is True, the index is ignored.
"""
if self.return_full_sequence:
target = []
for i in range(self.delay + 1):
target.append(self.data[:-i])
return self.data[self.delay :], torch.stack(
[self.data[:-i] for i in range(self.delay + 1)]
)
return self.data[idx + self.delay], self.data[idx]
def _generate_data(self):
"""Generate the data for the memory capacity dataset."""
torch.manual_seed(self.seed)
data = torch.empty(self.length).uniform_(-0.8, 0.8)
return data