
embeddings

embeddings provides a collection of pre-defined positional embeddings.

multimolecule.module.embeddings

RotaryEmbedding

Bases: Module

Rotary position embeddings based on those in RoFormer.

Queries and keys are transformed by rotation matrices that depend on their relative positions.

Cache

The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

Sequence Length

Rotary embeddings are independent of the sequence length and can be used with sequences of any length.
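
A minimal usage sketch, assuming RotaryEmbedding is importable from multimolecule.module.embeddings and that queries and keys are shaped [batch, heads, seq_len, head_dim] with head_dim equal to embedding_dim; the shapes below are illustrative only:

Python
import torch

from multimolecule.module.embeddings import RotaryEmbedding

# Illustrative shapes: batch=2, heads=4, seq_len=16, head_dim=32.
rotary = RotaryEmbedding(embedding_dim=32)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)

# The cos/sin tables are built lazily on the first call and reused
# until the sequence length or the device changes.
q_rot, k_rot = rotary(q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape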

Source code in multimolecule/module/embeddings/rotary.py
Python
@PositionEmbeddingRegistry.register("rotary")
@PositionEmbeddingRegistryHF.register("rotary")
class RotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Queries and keys are transformed by rotation
    matrices that depend on their relative positions.

    Tip: **Cache**
        The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

    Success: **Sequence Length**
        Rotary embeddings are independent of the sequence length and can be used with sequences of any length.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, embedding_dim, 2, dtype=torch.int64).float() / embedding_dim))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def forward(self, q: Tensor, k: Tensor) -> Tuple[Tensor, Tensor]:
        self._update_cos_sin_tables(k, seq_dimension=-2)

        return (self.apply_rotary_pos_emb(q), self.apply_rotary_pos_emb(k))

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb(self, x):
        cos = self._cos_cached[:, :, : x.shape[-2], :]
        sin = self._sin_cached[:, :, : x.shape[-2], :]

        return (x * cos) + (self.rotate_half(x) * sin)

    @staticmethod
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
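
To illustrate the relative-position property numerically, the sketch below (same assumed import path, illustrative shapes) repeats a single query vector and a single key vector across all positions; after rotation, the dot product between positions m and n depends only on the offset n - m:

Python
import torch

from multimolecule.module.embeddings import RotaryEmbedding

torch.manual_seed(0)
rotary = RotaryEmbedding(embedding_dim=16)

# Repeat one query vector and one key vector at every position, so any
# variation in the scores comes from the rotation alone.
q = torch.randn(1, 1, 1, 16).expand(1, 1, 8, 16).contiguous()
k = torch.randn(1, 1, 1, 16).expand(1, 1, 8, 16).contiguous()
q_rot, k_rot = rotary(q, k)


def score(m: int, n: int) -> torch.Tensor:
    # Dot product between the rotated query at position m and the rotated key at position n.
    return (q_rot[0, 0, m] * k_rot[0, 0, n]).sum()


# Both pairs have offset 3, so their scores match.
assert torch.allclose(score(0, 3), score(2, 5), atol=1e-5)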

SinusoidalEmbedding

Bases: Embedding

Sinusoidal positional embeddings for inputs of any length.

Freezing

The embeddings are frozen and cannot be trained. They will not be saved in the model’s state_dict.

Padding Idx

Padding symbols are ignored if the padding_idx is specified.

Sequence Length

These embeddings are automatically extended in forward if more positions are needed.
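
A minimal usage sketch, assuming SinusoidalEmbedding is importable from multimolecule.module.embeddings; the sizes and token ids below are illustrative:

Python
import torch

from multimolecule.module.embeddings import SinusoidalEmbedding

# Illustrative sizes; 0 is used as the padding token id.
embedding = SinusoidalEmbedding(num_embeddings=32, embedding_dim=64, padding_idx=0)

input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 0]])
pos_embed = embedding(input_ids)  # shape: (2, 5, 64)

# Padding positions map to the all-zero row at padding_idx.
assert torch.all(pos_embed[0, 3:] == 0)

# The frozen weights are excluded from the checkpoint.
assert embedding.state_dict() == {}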

Source code in multimolecule/module/embeddings/sinusoidal.py
Python
@PositionEmbeddingRegistry.register("sinusoidal")
@PositionEmbeddingRegistryHF.register("sinusoidal")
class SinusoidalEmbedding(nn.Embedding):
    r"""
    Sinusoidal positional embeddings for inputs of any length.

    Note: **Freezing**
        The embeddings are frozen and cannot be trained.
        They will not be saved in the model's state_dict.

    Tip: **Padding Idx**
        Padding symbols are ignored if the padding_idx is specified.

    Success: **Sequence Length**
        These embeddings are automatically extended in forward if more positions are needed.
    """

    _is_hf_initialized = True

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, bias: int = 0):
        weight = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        super().__init__(num_embeddings, embedding_dim, padding_idx, _weight=weight.detach(), _freeze=True)
        self.bias = bias

    def update_weight(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
        weight = self.get_embedding(num_embeddings, embedding_dim, padding_idx).to(
            dtype=self.weight.dtype, device=self.weight.device  # type: ignore[has-type]
        )
        self.weight = nn.Parameter(weight.detach(), requires_grad=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> Tensor:
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    @staticmethod
    def get_position_ids(tensor, padding_idx: int | None = None):
        """
        Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        if padding_idx is None:
            return torch.cumsum(tensor.new_ones(tensor.size(1)).long(), dim=0) - 1
        mask = tensor.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

    def forward(self, input_ids: Tensor) -> Tensor:
        _, seq_len = input_ids.shape[:2]
        # expand embeddings if needed
        max_pos = seq_len + self.bias + 1
        if self.padding_idx is not None:
            max_pos += self.padding_idx
        if max_pos > self.weight.size(0):
            self.update_weight(max_pos, self.embedding_dim, self.padding_idx)
        # Need to shift the position ids by the padding index
        position_ids = self.get_position_ids(input_ids, self.padding_idx) + self.bias
        return super().forward(position_ids)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        return {}

    def load_state_dict(self, *args, state_dict, strict=True):
        return

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        return

get_embedding staticmethod

Python
get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> Tensor

Build sinusoidal embeddings.

This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
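
To make the layout concrete, the sketch below rebuilds one row of the table by hand: the frequencies follow 10000 ** (-i / (half_dim - 1)), and all sine components come before all cosine components rather than being interleaved as in the paper's description. Sizes are illustrative and the import path is assumed as above:

Python
import math

import torch

from multimolecule.module.embeddings import SinusoidalEmbedding

num_embeddings, embedding_dim, pos = 8, 6, 3
half_dim = embedding_dim // 2

# Frequencies: 10000 ** (-i / (half_dim - 1)) for i = 0 .. half_dim - 1.
freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
angles = pos * freqs

# Row layout: all sines first, then all cosines.
row = torch.cat([torch.sin(angles), torch.cos(angles)])

table = SinusoidalEmbedding.get_embedding(num_embeddings, embedding_dim)
assert torch.allclose(table[pos], row)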

Source code in multimolecule/module/embeddings/sinusoidal.py
Python
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> Tensor:
    """
    Build sinusoidal embeddings.

    This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
    "Attention Is All You Need".
    """
    half_dim = embedding_dim // 2
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
    emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
    if embedding_dim % 2 == 1:
        # zero pad
        emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
    if padding_idx is not None:
        emb[padding_idx, :] = 0
    return emb

get_position_ids staticmethod

Python
get_position_ids(tensor, padding_idx: int | None = None)

Replace non-padding symbols with their position numbers.

Position numbers begin at padding_idx+1. Padding symbols are ignored.
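
A small worked example (illustrative token ids, import path assumed as above): with padding_idx = 1, non-padding tokens are numbered padding_idx + 1, padding_idx + 2, ... from left to right, and every padding position maps to padding_idx itself:

Python
import torch

from multimolecule.module.embeddings import SinusoidalEmbedding

input_ids = torch.tensor([[5, 7, 9, 1, 1],
                          [3, 4, 6, 8, 1]])

position_ids = SinusoidalEmbedding.get_position_ids(input_ids, padding_idx=1)
# tensor([[2, 3, 4, 1, 1],
#         [2, 3, 4, 5, 1]])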

Source code in multimolecule/module/embeddings/sinusoidal.py
Python
@staticmethod
def get_position_ids(tensor, padding_idx: int | None = None):
    """
    Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    if padding_idx is None:
        return torch.cumsum(tensor.new_ones(tensor.size(1)).long(), dim=0) - 1
    mask = tensor.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx