# Source code for torchdyno.models.rnn_assembly.skew_symm_coupling
import random
from typing import (
List,
Literal,
Tuple,
Union,
)
import torch
import torch.nn.functional as F
from torch import nn
from torchdyno.models.initializers import block_diagonal_coupling
[docs]
class SkewAntisymmetricCoupling(nn.Module):
    """Linear layer whose weight matrix is skew-symmetric by construction.

    A raw trainable matrix is built by ``block_diagonal_coupling`` from
    ``(i, j, block)`` triples. The effective weight is ``M - M.T``, where ``M``
    is the raw matrix multiplied by a fixed boolean mask of the entries that
    were non-zero at construction time. Entries that start at zero (including
    zeros inside a supplied block) therefore stay zero during training, and
    the effective weight always satisfies ``W.T == -W``.
    """

    def __init__(
        self,
        block_sizes: List[int],
        coupling_blocks: List[torch.Tensor],
        coupling_topology: List[Tuple[int, int]],
    ):
        """Initializes the skew antisymmetric coupling layer.

        Args:
            block_sizes (List[int]): list of block sizes.
            coupling_blocks (List[torch.Tensor]): one coupling block per entry
                of ``coupling_topology``, in the same order.
            coupling_topology (List[Tuple[int, int]]): ``(i, j)`` block index
                pairs that are coupled.

        Raises:
            ValueError: if ``coupling_blocks`` and ``coupling_topology`` differ
                in length.
        """
        super().__init__()
        self._block_sizes = block_sizes
        self._coupling_topology = coupling_topology

        if len(coupling_blocks) != len(coupling_topology):
            raise ValueError(
                "The number of coupling blocks must be equal to the number of coupling topologies."
            )

        placed_blocks = [
            (i, j, block)
            for (i, j), block in zip(coupling_topology, coupling_blocks)
        ]
        self._couplings = nn.Parameter(
            torch.tensor(block_diagonal_coupling(block_sizes, placed_blocks)),
        )
        # Frozen sparsity pattern: only entries that are non-zero right now
        # ever contribute to the effective coupling matrix.
        self._couple_mask = nn.Parameter(self._couplings != 0, requires_grad=False)
        # Effective matrix memoized for eval mode only (see ``couplings``).
        self._cached_coupling = None

    def train(self, mode: bool = True):
        """Switches training/eval mode and drops the memoized coupling matrix.

        ``eval()`` is routed through ``train(False)``, so any mode change
        forces the next ``couplings`` access to rebuild the matrix from the
        current parameter values instead of reusing one computed before the
        last optimizer step.
        """
        self._cached_coupling = None
        return super().train(mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the coupling: returns ``x @ self.couplings.T``."""
        return F.linear(x, self.couplings)

    @property
    def couplings(self) -> torch.Tensor:
        """Effective skew-symmetric coupling matrix ``M - M.T``.

        In training mode the matrix is rebuilt on every access and never
        retained. In eval mode it is computed once and reused until the mode
        changes or the parameter moves to another device/dtype.

        Note: an in-place parameter change made while already in eval mode
        (e.g. ``load_state_dict``) is not detected; call ``eval()`` or
        ``train()`` again afterwards to refresh the matrix.
        """
        if self.training:
            # Parameters change between steps; keeping a copy here would leak
            # a stale matrix (and its autograd graph) into a later eval phase.
            self._cached_coupling = None
            return self._build_couplings()

        cached = self._cached_coupling
        if (
            cached is None
            or cached.device != self._couplings.device
            or cached.dtype != self._couplings.dtype
        ):
            cached = self._build_couplings()
            self._cached_coupling = cached
        return cached

    def _build_couplings(self) -> torch.Tensor:
        """Computes ``M - M.T`` from the masked raw parameter."""
        masked: torch.Tensor = self._couple_mask * self._couplings
        return masked - masked.T
[docs]
def get_coupling_indices(
    block_sizes: List[int],
    coupling_topology: Union[int, float, Literal["ring"], List[Tuple[int, int]]],
) -> List[Tuple[int, int]]:
    """Resolves a topology specification into a list of block index pairs.

    Args:
        block_sizes (List[int]): list of block sizes; only its length is used.
        coupling_topology (Union[int, float, Literal["ring"], List[Tuple[int, int]]]):
            one of

            * a number: a random sample of the unordered pairs ``(i, j)`` with
              ``i < j``. A value in ``(0, 1]`` is read as a fraction of all
              pairs (note this includes the integer ``1``, which selects every
              pair); any other value is truncated to an int and capped at the
              number of pairs.
            * ``"ring"``: each block ``i`` is paired with ``(i + 1) % n``.
            * anything else: returned unchanged as the explicit pair list.

    Returns:
        List[Tuple[int, int]]: list of coupling indices.
    """
    # Non-numeric specifications: either the named ring layout or an
    # explicit list that is passed straight through.
    if not isinstance(coupling_topology, (int, float)):
        if coupling_topology == "ring":
            n_blocks = len(block_sizes)
            return [(idx, (idx + 1) % n_blocks) for idx in range(n_blocks)]
        return coupling_topology

    # Numeric specification: enumerate every pair i < j in lexicographic
    # order, then draw a random subset without replacement.
    n_blocks = len(block_sizes)
    candidate_pairs = []
    for first in range(n_blocks - 1):
        for second in range(first + 1, n_blocks):
            candidate_pairs.append((first, second))

    requested = coupling_topology
    if 0 < requested <= 1:
        # Fractional request: convert to an absolute pair count.
        requested = int(requested * len(candidate_pairs))

    sample_size = int(min(len(candidate_pairs), requested))
    return random.sample(candidate_pairs, sample_size)