Source code for torchdyno.data.datasets.seq_mnist

from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Tuple,
)

import torch
from torchvision import transforms
from torchvision.datasets import MNIST


class SequentialMNIST(MNIST):
    """Sequential MNIST dataset.

    Each image is flattened row by row (its rows are concatenated) and
    returned as a sequence with one pixel per time step. When
    ``permute_seed`` is given, the pixels are additionally shuffled by a
    fixed permutation derived from that seed (the "permuted sequential
    MNIST" variant); the same permutation is applied to every image.
    Without a seed, pixels keep their row-major order.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable[..., torch.Tensor]] = None,
        target_transform: Optional[Callable[..., Any]] = None,
        download: bool = False,
        permute_seed: Optional[int] = None,
    ):
        """Sequential MNIST dataset.

        Args:
            root (str): root directory of the dataset.
            train (bool, optional): whether to load the training set (True)
                or the test set (False). Defaults to True.
            transform (Optional[Callable[..., torch.Tensor]], optional): a
                function/transform that takes in a PIL image and returns a
                tensor. Defaults to ``transforms.ToTensor()`` wrapped in a
                ``transforms.Compose``.
            target_transform (Optional[Callable[..., Any]], optional): a
                function/transform that takes in the target and transforms
                it. Defaults to None.
            download (bool, optional): whether to download the dataset.
                Defaults to False.
            permute_seed (Optional[int], optional): seed of the fixed pixel
                permutation. If None, no permutation is applied.
                Defaults to None.
        """
        if transform is None:
            # ``__getitem__`` flattens a tensor, so PIL images must be
            # converted by default.
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        self.permute_seed = permute_seed
        # Permutations keyed by (seed, sequence length): generated once
        # instead of on every ``__getitem__`` call. Keying on the seed keeps
        # the cache correct if ``permute_seed`` is reassigned later.
        self._permutations: Dict[Tuple[int, int], torch.Tensor] = {}

    def _permutation(self, length: int) -> torch.Tensor:
        """Return the cached permutation of ``range(length)`` for ``permute_seed``."""
        key = (self.permute_seed, length)
        perm = self._permutations.get(key)
        if perm is None:
            # A fresh generator seeded identically each time makes the
            # permutation deterministic and identical across items.
            generator = torch.Generator().manual_seed(self.permute_seed)
            perm = torch.randperm(length, generator=generator)
            self._permutations[key] = perm
        return perm

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Any]:
        """Return the pixel sequence and target at ``index``.

        Args:
            index (int): index of the sample.

        Returns:
            Tuple[torch.Tensor, Any]: ``(seq, target)`` where ``seq`` has
            shape ``(img.numel(), 1)`` (e.g. ``(784, 1)`` for 28x28 images
            with the default transform), permuted if ``permute_seed`` is
            set, and ``target`` is the label as returned by the parent
            ``MNIST`` (after ``target_transform``, if any).
        """
        img, target = super().__getitem__(index)
        # ``reshape`` (unlike ``view``) also handles non-contiguous tensors
        # a custom transform may return; it is still a view when possible.
        seq = img.reshape(-1)
        if self.permute_seed is not None:
            seq = seq[self._permutation(seq.numel())]
        # Trailing feature dimension: one pixel per time step.
        return seq.unsqueeze(-1), target