Source code for torchdyno.data.utils.seq_loader

import torch


def _stack_batch(inputs, targets):
    """Stack per-sample inputs and targets into a pair of batch tensors.

    Inputs are always stacked along a new dimension 1. Targets are stacked
    the same way when they are tensors; otherwise (e.g. plain Python
    numbers) they are converted with ``torch.tensor``.
    """
    # Only the first target is inspected; the batch is assumed homogeneous.
    if isinstance(targets[0], torch.Tensor):
        stacked_targets = torch.stack(targets, dim=1)
    else:
        stacked_targets = torch.tensor(targets)
    return torch.stack(inputs, dim=1), stacked_targets


def seq_collate_fn(scenario: str = "stationary"):
    """Build a collate function for batches of sequence samples.

    The returned callable takes a batch (an iterable of samples) and returns
    a tuple ``(x, y)`` where the per-sample inputs are stacked along a new
    dimension 1. Tensor targets are stacked along dimension 1 as well;
    non-tensor targets are converted with ``torch.tensor``.

    Args:
        scenario: Selects the expected sample layout.
            ``"stationary"`` expects ``(x, y)`` pairs.
            ``"continual"`` expects ``(x, y, extra)`` triples; the third
            element is ignored (presumably a task label -- confirm against
            the dataset that produces the samples).

    Returns:
        A function ``_collate_fn(batch) -> (x, y)``.

    Raises:
        ValueError: If ``scenario`` is not one of the supported values. Only
            the first 10 characters of the offending value are reported.
    """
    if scenario == "stationary":

        def _collate_fn(batch):
            inputs, targets = [], []
            for x_i, y_i in batch:
                inputs.append(x_i)
                targets.append(y_i)
            return _stack_batch(inputs, targets)

    elif scenario == "continual":

        def _collate_fn(batch):
            inputs, targets = [], []
            # The third element of every sample is intentionally discarded.
            for x_i, y_i, _ in batch:
                inputs.append(x_i)
                targets.append(y_i)
            # Shares the stationary label handling so that plain Python
            # labels do not crash torch.stack.
            return _stack_batch(inputs, targets)

    else:
        # str() keeps the slice safe when a non-string value is passed, so
        # the caller gets the intended ValueError instead of a TypeError.
        raise ValueError(
            f"Unknown scenario for collate_fn: {str(scenario)[:10]}"
        )
    return _collate_fn