"""Source code for ``torchdyno.models.esn.reservoir``: the Echo State Network reservoir layer."""

import logging
from copy import deepcopy
from typing import (
    Callable,
    Iterator,
    List,
    Optional,
    Union,
)

import torch
import torch.nn.functional as F
from torch import (
    Generator,
    Size,
    Tensor,
)
from torch.nn import (
    Module,
    Parameter,
)

from torchdyno.models import initializers


class Reservoir(Module):
    """A reservoir for Echo State Networks.

    At each timestep the state is updated as::

        z_t = W_in x_t + b + W_hat h_{t-1}
        h_t = (1 - alpha) * h_{t-1} + alpha * f(z_t)

    where, if ``net_gain_and_bias`` is enabled, ``z_t`` is replaced by
    ``net_a * z_t + net_b`` before applying ``f``.

    Args:
        input_size: the number of expected features in the input `x`.
        hidden_size: the number of features in the hidden state `h`.
        activation: name of an activation function in the ``torch`` namespace
            (e.g. ``"tanh"`` for ``torch.tanh``), or a callable applied to the
            pre-activation. Default: ``"tanh"``.
        leakage: the value of the leaking parameter `alpha`. Default: ``1.0``.
        input_scaling: the desired scaling of the input (should be ``<= 1``;
            a warning is logged otherwise). Also used as the bound passed to
            ``initializers.uniform`` for the bias. Default: ``0.9``.
        rho: the desired spectral radius of the recurrent matrix (should be
            ``< 1``; a warning is logged otherwise). Default: ``0.99``.
        bias: if ``False``, the layer does not use bias weights `b`.
            Default: ``False``.
        kernel_initializer: name of a function in ``initializers`` or a
            callable taking a shape and returning a tensor; used for the input
            matrix. Default: ``'uniform'``.
        recurrent_initializer: same as ``kernel_initializer``, but for the
            recurrent matrix. Default: ``'uniform'``.
        net_gain_and_bias: if ``True``, the network uses additional ``g``
            (gain, ``net_a``) and ``b`` (bias, ``net_b``) parameters.
            Default: ``False``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        activation: Union[str, Callable[[Tensor], Tensor]] = "tanh",
        leakage: float = 1.0,
        input_scaling: float = 0.9,
        rho: float = 0.99,
        bias: bool = False,
        kernel_initializer: Union[str, Callable[[Size], Tensor]] = "uniform",
        recurrent_initializer: Union[str, Callable[[Size], Tensor]] = "uniform",
        net_gain_and_bias: bool = False,
    ):
        super().__init__()
        if input_scaling > 1:
            logging.warning("Input scaling must be <= 1")
        if rho >= 1:
            logging.warning("Spectral radius must be < 1")

        # Hyper-parameters are kept as frozen Parameters so they are part of
        # the state dict and follow the module through `.to(...)`.
        self.input_scaling = Parameter(torch.tensor(input_scaling), requires_grad=False)
        self.rho = Parameter(torch.tensor(rho), requires_grad=False)

        kernel_init = self._resolve_initializer(kernel_initializer)
        recurrent_init = self._resolve_initializer(recurrent_initializer)

        self.W_in = Parameter(
            initializers.rescale(
                kernel_init([hidden_size, input_size]), "linear", input_scaling
            ),
            requires_grad=False,
        )
        self.W_hat = Parameter(
            initializers.rescale(
                recurrent_init([hidden_size, hidden_size]), "spectral", rho
            ),
            requires_grad=False,
        )
        self.b = (
            Parameter(
                initializers.uniform([hidden_size], -input_scaling, input_scaling),
                requires_grad=False,
            )
            if bias
            else None
        )
        self.f = getattr(torch, activation) if isinstance(activation, str) else activation
        self.alpha = Parameter(torch.tensor(leakage), requires_grad=False)

        self.net_gain_and_bias = net_gain_and_bias
        if net_gain_and_bias:
            # NOTE(review): `forward` runs under `torch.no_grad()`, so these
            # receive no gradients through it — confirm how they are trained.
            self.net_a = Parameter(initializers.ones((hidden_size,)), requires_grad=True)
            self.net_b = Parameter(initializers.zeros((hidden_size,)), requires_grad=True)

        # Optional replacement for `_state_comp`: when set, `forward` calls it
        # with (input, initial_state, mask) and stacks the states it yields.
        self._aux_fwd_comp: Optional[Callable[..., Iterator[Tensor]]] = None

    @staticmethod
    def _resolve_initializer(
        initializer: Union[str, Callable[[Size], Tensor]]
    ) -> Callable[[Size], Tensor]:
        """Return ``initializer`` itself, or the ``initializers`` function it names."""
        if isinstance(initializer, str):
            return getattr(initializers, initializer)
        return initializer

    @torch.no_grad()
    def forward(
        self,
        input: Tensor,
        initial_state: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the reservoir over a sequence and return every state.

        Args:
            input: the input sequence; dim 0 indexes timesteps. It is cast to
                the dtype/device of ``W_hat``.
            initial_state: the starting hidden state. Defaults to a zero vector
                of size ``hidden_size``.
            mask: optional tensor passed to the state computation; the default
                one multiplies each *emitted* state by it (the recurrence
                itself keeps the unmasked state).

        Returns:
            The emitted states stacked along a new leading dim 0.
        """
        if initial_state is None:
            initial_state = torch.zeros(self.hidden_size).to(self.W_hat)
        _fwd_comp = (
            self._state_comp if self._aux_fwd_comp is None else self._aux_fwd_comp
        )
        embeddings = torch.stack(
            [state for state in _fwd_comp(input.to(self.W_hat), initial_state, mask)],  # type: ignore[operator]
            dim=0,
        )
        return embeddings

    def _state_comp(
        self, input: Tensor, initial_state: Tensor, mask: Optional[Tensor] = None
    ) -> Iterator[Tensor]:
        """Yield the reservoir state for each timestep of ``input`` (dim 0)."""
        timesteps = input.shape[0]
        state = initial_state
        for t in range(timesteps):
            in_signal_t = F.linear(
                input[t].to(self.W_in), self.W_in, self.b
            ) + F.linear(state, self.W_hat)
            if self.net_gain_and_bias:
                in_signal_t = in_signal_t * self.net_a + self.net_b
            # Use the configured activation (previously hard-coded to tanh).
            h_t = self.f(in_signal_t)
            state = (1 - self.alpha) * state + self.alpha * h_t
            yield state if mask is None else mask * state

    def merge_reservoirs(
        self,
        others: Union["Reservoir", List["Reservoir"]],
        joint_scaling: Optional[float] = None,
        coupled: bool = False,
        independent_inputs: bool = False,
    ) -> "Reservoir":
        """Merge this reservoir with one or more others into a single reservoir.

        The result is a deep copy of ``self`` whose hidden state is the
        concatenation of the hidden states of ``self`` followed by ``others``
        (in order). Everything other than ``W_in``, ``W_hat`` and ``b``
        (leakage, activation, ``rho``, ``input_scaling``) is taken from ``self``.

        Args:
            others: a reservoir, or a list (any iterable) of reservoirs.
            joint_scaling: spectral radius for the merged recurrent matrix when
                ``coupled`` is ``True``; defaults to ``self.rho``. Ignored
                otherwise.
            coupled: if ``True``, the off-diagonal blocks of the recurrent
                matrix are drawn uniformly from ``[-1, 1]`` and the whole
                matrix is rescaled with ``initializers.rescale``; otherwise
                they are zero and each sub-reservoir keeps its recurrent matrix.
            independent_inputs: if ``True``, each sub-reservoir reads its own
                slice of the (concatenated) input; otherwise all sub-reservoirs
                read the same input and must share ``input_size``.

        Returns:
            The merged reservoir. ``self`` and ``others`` are not modified.

        Raises:
            ValueError: if any reservoir uses net gain and bias, or if
                ``independent_inputs`` is ``False`` and input sizes differ.
        """
        others = [others] if isinstance(others, Reservoir) else list(others)
        parts = [self, *others]
        if any(r.net_gain_and_bias for r in parts):
            raise ValueError("Cannot merge reservoirs with net gain and bias")
        if not independent_inputs and any(
            r.input_size != self.input_size for r in others
        ):
            raise ValueError(
                "Cannot merge reservoirs with different input sizes "
                "unless `independent_inputs=True`"
            )

        new_hsize = sum(r.hidden_size for r in parts)
        new_insize = (
            sum(r.input_size for r in parts) if independent_inputs else self.input_size
        )

        # Allocate on self's device/dtype so merging does not move tensors to CPU.
        W_in = self.W_in.new_zeros(new_hsize, new_insize)
        if coupled:
            W_hat = self.W_hat.new_empty(new_hsize, new_hsize).uniform_(-1, 1)
        else:
            W_hat = self.W_hat.new_zeros(new_hsize, new_hsize)
        b = self.W_in.new_zeros(new_hsize) if any(r.b is not None for r in parts) else None

        # Place each reservoir's weights in its own diagonal block.
        h_off = in_off = 0
        for r in parts:
            h_end = h_off + r.hidden_size
            if independent_inputs:
                in_end = in_off + r.input_size
                W_in[h_off:h_end, in_off:in_end] = r.W_in.data
                in_off = in_end
            else:
                W_in[h_off:h_end] = r.W_in.data
            W_hat[h_off:h_end, h_off:h_end] = r.W_hat.data
            if r.b is not None and b is not None:
                b[h_off:h_end] = r.b.data
            h_off = h_end

        if coupled:
            W_hat = initializers.rescale(
                W_hat,
                "spectral",
                joint_scaling if joint_scaling is not None else self.rho.data,
            )

        new_reservoir = deepcopy(self)
        new_reservoir.W_in.data = W_in
        new_reservoir.W_hat.data = W_hat
        # Keep the merged bias frozen, like the bias created in `__init__`.
        new_reservoir.b = None if b is None else Parameter(b, requires_grad=False)
        return new_reservoir

    @property
    def input_size(self) -> int:
        """Input dimension."""
        return self.W_in.shape[1]

    @property
    def hidden_size(self) -> int:
        """Reservoir state dimension."""
        return self.W_hat.shape[1]