Source code for torchdyno.models.rnn_assembly.rnn_assembly

from typing import (
    Callable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
import torch.nn.functional as F
from torch import (
    Tensor,
    nn,
)
from torch.utils.data import DataLoader

from torchdyno.models import initializers
from torchdyno.optim.ridge_regression import (
    fit_and_validate_readout,
    fit_readout,
)

from .block_diagonal import BlockDiagonal
from .skew_symm_coupling import (
    SkewAntisymmetricCoupling,
    get_coupling_indices,
)


class RNNAssembly(nn.Module):
    """RNN of RNNs: recurrent blocks joined by skew-antisymmetric couplings.

    The hidden state is advanced with an explicit Euler step::

        state <- state + eul_step * (
            -state
            + blocks(activation(state))
            + couplings(state)
            + input_projection(input[t])
        )

    and a linear readout (``_out_mat``) maps the collected states to outputs.
    """

    def __init__(
        self,
        input_size: int,
        out_size: int,
        blocks: List[torch.Tensor],
        coupling_blocks: List[torch.Tensor],
        coupling_topology: List[Tuple[int, int]],
        eul_step: float = 1e-2,
        activation: str = "tanh",
        constrained_blocks: Optional[
            Literal["fixed", "tanh", "clip", "orthogonal"]
        ] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Initializes the RNN of RNNs layer.

        Args:
            input_size (int): size of the input.
            out_size (int): size of the output.
            blocks (List[torch.Tensor]): list of blocks, forwarded to
                ``BlockDiagonal``.
            coupling_blocks (List[torch.Tensor]): list of coupling blocks,
                forwarded to ``SkewAntisymmetricCoupling``.
            coupling_topology (List[Tuple[int, int]]): pairs of block indices
                that are coupled, forwarded to ``SkewAntisymmetricCoupling``.
            eul_step (float, optional): Euler step. Defaults to 1e-2.
            activation (str, optional): name of a function in the ``torch``
                namespace used as activation. Defaults to "tanh".
            constrained_blocks (Optional[Literal["fixed", "tanh", "clip", "orthogonal"]], optional):
                type of constraint, forwarded to ``BlockDiagonal``.
                Defaults to None.
            dtype (torch.dtype, optional): data type of the input and readout
                matrices. Defaults to torch.float32.

        Raises:
            AttributeError: if ``activation`` is not an attribute of ``torch``.
        """
        super().__init__()
        self._input_size = input_size
        self._eul_step = eul_step
        self._activation = activation
        self._dtype = dtype

        self._blocks = BlockDiagonal(
            blocks=blocks,
            constrained=constrained_blocks,
        )
        self._couplings = SkewAntisymmetricCoupling(
            block_sizes=self._blocks.block_sizes,
            coupling_blocks=coupling_blocks,
            coupling_topology=coupling_topology,
        )

        # Both projections are drawn from N(0, 1 / hidden_size).
        init_std = 1 / np.sqrt(self.hidden_size)

        # Frozen input projection. It is consumed through ``F.linear``, which
        # expects a weight laid out as (out_features, in_features), hence
        # (hidden_size, input_size).
        self._input_mat = nn.Parameter(
            torch.normal(
                mean=0,
                std=init_std,
                size=(self.hidden_size, self._input_size),
                dtype=self._dtype,
            ),
            requires_grad=False,
        )
        # Trainable readout, applied as ``states @ _out_mat``.
        self._out_mat = nn.Parameter(
            torch.normal(
                mean=0,
                std=init_std,
                size=(self.hidden_size, out_size),
                dtype=self._dtype,
            ),
        )
        self.activ_fn = getattr(torch, self._activation)

    @staticmethod
    def from_initializers(
        input_size: int,
        out_size: int,
        block_sizes: List[int],
        block_init_fn: Union[str, Callable[[torch.Size], torch.Tensor]],
        coupling_block_init_fn: Union[str, Callable[[torch.Size], torch.Tensor]],
        coupling_topology: Union[int, float, List[Tuple[int, int]], Literal["ring"]],
        eul_step: float = 1e-2,
        activation: str = "tanh",
        constrained_blocks: Optional[
            Literal["fixed", "tanh", "clip", "orthogonal"]
        ] = None,
        dtype: torch.dtype = torch.float32,
    ) -> "RNNAssembly":
        """Create an RNNAssembly from initializers.

        Initializers given as strings are looked up by name in
        ``torchdyno.models.initializers``. Every initializer is invoked as
        ``init_fn((rows, cols), dtype)``.

        Args:
            input_size (int): size of the input.
            out_size (int): size of the output.
            block_sizes (List[int]): list of block sizes; block ``k`` is a
                square matrix of side ``block_sizes[k]``.
            block_init_fn (Union[str, Callable]): block initializer.
            coupling_block_init_fn (Union[str, Callable]): coupling block
                initializer; coupling ``(i, j)`` has shape
                ``(block_sizes[i], block_sizes[j])``.
            coupling_topology (Union[int, float, List[Tuple[int, int]], Literal["ring"]]):
                coupling topology, resolved to index pairs by
                ``get_coupling_indices``.
            eul_step (float, optional): Euler step. Defaults to 1e-2.
            activation (str, optional): activation function. Defaults to "tanh".
            constrained_blocks (Optional[Literal["fixed", "tanh", "clip", "orthogonal"]], optional):
                type of constraint. Defaults to None.
            dtype (torch.dtype, optional): data type. Defaults to torch.float32.

        Returns:
            RNNAssembly: the assembled model.
        """
        make_block: Callable = (
            getattr(initializers, block_init_fn)
            if isinstance(block_init_fn, str)
            else block_init_fn
        )
        make_coupling: Callable = (
            getattr(initializers, coupling_block_init_fn)
            if isinstance(coupling_block_init_fn, str)
            else coupling_block_init_fn
        )

        blocks = [make_block((size, size), dtype) for size in block_sizes]
        coupling_indices = get_coupling_indices(block_sizes, coupling_topology)
        coupling_blocks = [
            make_coupling((block_sizes[i], block_sizes[j]), dtype)
            for i, j in coupling_indices
        ]
        return RNNAssembly(
            input_size=input_size,
            out_size=out_size,
            blocks=blocks,
            coupling_blocks=coupling_blocks,
            coupling_topology=coupling_indices,
            eul_step=eul_step,
            activation=activation,
            constrained_blocks=constrained_blocks,
            dtype=dtype,
        )

    def forward(
        self,
        input: torch.Tensor,
        initial_state: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the dynamics over ``input`` and apply the readout.

        Args:
            input (torch.Tensor): input sequence, iterated along dimension 0.
            initial_state (Optional[torch.Tensor], optional): starting hidden
                state; zeros when None. Defaults to None.
            mask (Optional[torch.Tensor], optional): multiplied into every
                recorded state. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(output, states)`` where
            ``output = states @ _out_mat`` (its absolute value when the model
            dtype is ``torch.complex64``).
        """
        # ``compute_states`` builds the zero initial state itself when needed.
        states = self.compute_states(input, initial_state, mask)
        output = states @ self._out_mat
        if self._dtype == torch.complex64:
            output = torch.abs(output)
        return output, states

    def compute_states(
        self,
        input: torch.Tensor,
        initial_state: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Integrate the hidden state over the input sequence.

        Args:
            input (torch.Tensor): input sequence, iterated along dimension 0;
                its last dimension must equal ``input_size``.
            initial_state (Optional[torch.Tensor], optional): starting hidden
                state; zeros matching the input matrix's dtype and device when
                None. Defaults to None.
            mask (Optional[torch.Tensor], optional): multiplied into each
                state before it is recorded. The state carried to the next
                step is NOT masked. Defaults to None.

        Returns:
            torch.Tensor: the recorded states stacked along dimension 0.
        """
        if initial_state is None:
            state = torch.zeros(self.hidden_size, dtype=self._dtype).to(
                self._input_mat
            )
        else:
            state = initial_state

        # The input drive does not depend on the state, so project the whole
        # sequence once instead of once per time step.
        drive = F.linear(input, self._input_mat)

        states = []
        for t in range(input.shape[0]):
            state = state + self._eul_step * (
                -state
                + self._blocks(self.activ_fn(state))
                + self._couplings(state)
                + drive[t]
            )
            states.append(state if mask is None else mask * state)
        return torch.stack(states, dim=0)

    def fit_readout(
        self,
        train_loader: DataLoader,
        l2_value: Union[float, List[float]] = 1e-9,
        washout: int = 0,
        score_fn: Optional[Callable[[Tensor, Tensor], float]] = None,
        mode: Optional[Literal["min", "max"]] = None,
        eval_on: Optional[Union[Literal["train"], DataLoader]] = None,
    ) -> Optional[float]:
        """Fit the readout layer.

        Without ``eval_on`` the readout is fitted once on ``train_loader``.
        With ``eval_on`` every candidate in ``l2_value`` is fitted and scored,
        and the readout returned by ``fit_and_validate_readout`` is kept.

        Args:
            train_loader (DataLoader): training data loader.
            l2_value (Union[float, List[float]], optional): L2 regularization
                value(s). A scalar is wrapped in a list when validating.
                Without validation the value is forwarded unchanged as ``l2``.
                NOTE(review): confirm a list is acceptable on that path.
                Defaults to 1e-9.
            washout (int, optional): the amount of timesteps to skip in the
                training dataset to prepare the internal state of the RNN.
                Defaults to 0.
            score_fn (Optional[Callable[[Tensor, Tensor], float]], optional):
                scoring function; required when ``eval_on`` is given.
                Defaults to None.
            mode (Optional[Literal["min", "max"]], optional): whether to
                minimize or maximize the score; required when ``eval_on`` is
                given. Defaults to None.
            eval_on (Optional[Union[Literal["train"], DataLoader]], optional):
                evaluation data, or "train" to reuse ``train_loader``.
                Defaults to None.

        Returns:
            Optional[float]: the best score when validating, otherwise None.

        Raises:
            ValueError: if validation is requested without ``score_fn`` or
                ``mode``, or if ``eval_on`` is neither "train" nor a
                DataLoader.
        """
        # Compare against None explicitly: the truth value of a DataLoader
        # goes through ``__len__``, which raises for length-less iterable
        # datasets and is False for an empty loader.
        validate = eval_on is not None

        if not validate:
            readout = fit_readout(
                train_loader,
                preprocess_fn=self.compute_states,
                skip_first_n=washout,
                l2=l2_value,
                device=next(self.parameters()).device,
            )
            self._out_mat.data = readout
            return None

        if score_fn is None:
            raise ValueError("Score function must be provided for validation.")
        if mode is None:
            raise ValueError("Mode must be provided for optimization.")

        if eval_on == "train":
            eval_loader = train_loader
        elif isinstance(eval_on, DataLoader):
            eval_loader = eval_on
        else:
            raise ValueError("Evaluation data must be provided as DataLoader.")

        candidates = l2_value if isinstance(l2_value, list) else [l2_value]
        readout, _best_l2, best_score = fit_and_validate_readout(
            train_loader=train_loader,
            eval_loader=eval_loader,
            l2_values=candidates,
            preprocess_fn=self.compute_states,
            skip_first_n=washout,
            score_fn=score_fn,
            mode=mode,
            device=next(self.parameters()).device,
        )
        self._out_mat.data = readout
        return best_score

    @property
    def hidden_size(self) -> int:
        """Total hidden size, as reported by the block-diagonal layer."""
        return self._blocks.layer_size