Source code for torchdyno.models.rnn_assembly.block_diagonal
from typing import (
List,
Literal,
Optional,
Union,
)
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchdyno.models.initializers import block_diagonal
[docs]
class BlockDiagonal(nn.Module):
    """Linear layer whose weight is a block diagonal matrix.

    The raw weight is stored in the ``_blocks`` parameter. The weight that is
    actually applied in :meth:`forward` is exposed by the :attr:`blocks`
    property, which applies the constraint selected at construction time.
    """

    def __init__(
        self,
        blocks: List[torch.Tensor],
        bias: bool = False,
        constrained: Optional[Literal["fixed", "tanh", "clip", "orthogonal"]] = None,
    ):
        """Initializes the block diagonal matrix.

        Args:
            blocks (List[torch.Tensor]): list of blocks.
            bias (bool, optional): whether to use bias. Defaults to False.
            constrained (Optional[Literal["fixed", "tanh", "clip", "orthogonal"]], optional):
                type of constraint. Defaults to None.
        """
        super().__init__()
        self._block_sizes = [block.size(0) for block in blocks]
        self._constrained = constrained
        self._cached_blocks = None

        # "fixed" freezes the weight; every other mode leaves it trainable.
        self._blocks = nn.Parameter(
            block_diagonal(blocks),
            requires_grad=constrained != "fixed",
        )
        # Entries that are non-zero at initialization. Kept as a frozen
        # parameter (not a buffer) so `parameters()` is unchanged for callers.
        self._blocks_mask = nn.Parameter(self._blocks != 0, requires_grad=False)

        # Registered as a non-persistent buffer so that `.to(...)`/`.double()`
        # move it together with the weight, while `state_dict()` keys stay
        # exactly as before.
        self.register_buffer(
            "_support_eye",
            torch.eye(
                self.layer_size,
                dtype=self._blocks.dtype,
                device=self._blocks.device,
            ),
            persistent=False,
        )

        if bias:
            # `torch.normal` with scalar mean/std requires an explicit `size`;
            # without it the call raises a TypeError.
            self.bias = nn.Parameter(
                torch.normal(
                    mean=0.0,
                    std=float(1 / np.sqrt(self.layer_size)),
                    size=(self.layer_size,),
                    dtype=self._blocks.dtype,
                    device=self._blocks.device,
                ),
            )
        else:
            self.bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the constrained block diagonal weight (and bias) to ``x``.

        Args:
            x (torch.Tensor): input passed to ``F.linear``.

        Returns:
            torch.Tensor: ``F.linear(x, self.blocks, self.bias)``.
        """
        return F.linear(x, self.blocks, self.bias)

    def train(self, mode: bool = True) -> "BlockDiagonal":
        """Sets the training mode and drops the cached weight.

        The cache must not survive a mode switch: the parameters may have been
        updated since it was filled, and `eval()` would otherwise keep serving
        the weight computed before the last update.

        Args:
            mode (bool, optional): training (True) or evaluation (False).

        Returns:
            BlockDiagonal: self.
        """
        self._cached_blocks = None
        return super().train(mode)

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Loads the state as usual, then drops the now outdated cached weight."""
        super()._load_from_state_dict(*args, **kwargs)
        self._cached_blocks = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks on the diagonal."""
        return len(self._block_sizes)

    @property
    def block_sizes(self) -> List[int]:
        """Size (first dimension) of each block, in order."""
        return self._block_sizes

    @property
    def layer_size(self) -> int:
        """Total size of the layer, i.e. the sum of the block sizes."""
        return sum(self._block_sizes)

    @property
    def blocks(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Weight matrix after applying the configured constraint.

        The result is cached only when it is safe to reuse, i.e. when it does
        not carry an autograd graph and the layer is not being trained. The
        cache is discarded by :meth:`train`/``eval()``, by loading a state
        dict, and whenever the weight changes device or dtype. In-place edits
        of the parameters made in evaluation mode are not detected; call
        ``eval()`` again afterwards.

        Returns:
            torch.Tensor: the constrained weight matrix (always a tensor).
        """
        cached = self._cached_blocks
        being_trained = self.training and self._constrained != "fixed"
        # A graph-free cached tensor cannot provide gradients w.r.t. the weight.
        needs_graph = torch.is_grad_enabled() and self._blocks.requires_grad
        if (
            cached is not None
            and not being_trained
            and not needs_graph
            and cached.device == self._blocks.device
            and cached.dtype == self._blocks.dtype
        ):
            return cached

        blocks_ = self._constrained_blocks()
        # Never keep a tensor attached to an autograd graph: reusing it would
        # keep the graph alive and break a second backward pass.
        self._cached_blocks = None if blocks_.requires_grad else blocks_
        return blocks_

    def _constrained_blocks(self) -> torch.Tensor:
        """Computes the constrained weight from the raw ``_blocks`` parameter."""
        if self._constrained == "orthogonal":
            # B - B^T is skew-symmetric; I + A + A^2 / 2 is the second-order
            # truncation of the series of exp(A), so the result is only
            # approximately orthogonal.
            # NOTE(review): the mask is not applied in this branch, so entries
            # outside the initial non-zero pattern are not forced to stay
            # zero during training - confirm this is intended.
            skew = self._blocks - self._blocks.T
            return self._support_eye + skew + (skew @ skew) / 2

        # Zero out everything that was zero at initialization.
        masked = self._blocks * self._blocks_mask
        if self._constrained == "tanh":
            return torch.tanh(masked)
        if self._constrained == "clip":
            return torch.clamp(masked, -0.999, 0.999)
        return masked