
embeddings

The embeddings module provides a collection of pre-defined positional embeddings.

multimolecule.modules.embeddings

RotaryEmbedding

Bases: Module

Rotary position embeddings based on those in RoFormer.

Queries and keys are transformed by rotation matrices that depend on their relative positions.
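The rotation is easy to reproduce by hand. Below is a minimal, self-contained sketch (not the library implementation; rotate and its arguments are illustrative) of the half-rotation convention used in the source further down: each feature pair (i, i + dim/2) is rotated by an angle proportional to the token position, so the dot product between a rotated query and a rotated key depends only on their relative offset.

Python
import torch

def rotate_half(x):
    # swap the two halves and negate the second: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate(x, position, dim, base=10000.0):
    # rotate each feature pair (i, i + dim/2) by position * inv_freq[i]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = position * inv_freq
    cos = torch.cat((angles.cos(), angles.cos()))
    sin = torch.cat((angles.sin(), angles.sin()))
    return x * cos + rotate_half(x) * sin

dim = 64
q, k = torch.randn(dim), torch.randn(dim)
# scores at offset 4 agree regardless of the absolute positions
score_a = rotate(q, 7, dim) @ rotate(k, 3, dim)
score_b = rotate(q, 12, dim) @ rotate(k, 8, dim)
print(torch.allclose(score_a, score_b, atol=1e-4))  # True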

Cache

The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

Sequence Length

Rotary Embedding is independent of the sequence length and can be used with sequences of any length.

Example

Python
>>> embedding = RotaryEmbedding(embedding_dim=64)
>>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
>>> query, key = embedding(query, key)
>>> query.shape
torch.Size([2, 4, 28, 64])
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
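Building on the example above, the following sketch (the import path is an assumption) checks the relative-position property end to end with the module itself: prepending an equal number of positions to both query and key leaves the attention scores between the original tokens unchanged, and the longer sequence is handled without any reconfiguration.

Python
import torch
from multimolecule.modules.embeddings import RotaryEmbedding  # import path assumed

embedding = RotaryEmbedding(embedding_dim=64)
q, k = torch.randn(1, 1, 8, 64), torch.randn(1, 1, 8, 64)
q_rot, k_rot = embedding(q, k)
scores = q_rot @ k_rot.transpose(-2, -1)

# shift every token three positions to the right by prepending a random prefix
prefix_q, prefix_k = torch.randn(1, 1, 3, 64), torch.randn(1, 1, 3, 64)
qs, ks = embedding(torch.cat([prefix_q, q], dim=-2), torch.cat([prefix_k, k], dim=-2))
scores_shifted = qs @ ks.transpose(-2, -1)

# scores between the original tokens depend only on their relative positions
print(torch.allclose(scores, scores_shifted[..., 3:, 3:], atol=1e-4))  # True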

Source code in multimolecule/modules/embeddings/rotary.py
Python
@POSITION_EMBEDDINGS.register("rotary")
@POSITION_EMBEDDINGS_HF.register("rotary")
class RotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    Tip: **Cache**
        The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

    Success: **Sequence Length**
        Rotary Embedding is irrespective of the sequence length and can be used for any sequence length.

    Example:
        >>> embedding = RotaryEmbedding(embedding_dim=64)
        >>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
        >>> query, key = embedding(query, key)
        >>> query.shape
        torch.Size([2, 4, 28, 64])
        >>> embedding.state_dict()  # no weight in state_dict
        OrderedDict()
    """

    _seq_len_cached: int | None = None
    _cos_cached: Tensor = None
    _sin_cached: Tensor = None

    def __init__(self, embedding_dim: int, base: float = 10000.0, dtype: torch.dtype = torch.float32):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, embedding_dim, 2, dtype=dtype) / embedding_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, q: Tensor, k: Tensor) -> Tuple[Tensor, Tensor]:
        self._update_cos_sin_tables(k, seq_len_dim=-2)
        return self.apply_rotary_pos_emb(q), self.apply_rotary_pos_emb(k)

    def _update_cos_sin_tables(self, x: Tensor, seq_len_dim: int = 2) -> Tuple[Tensor, Tensor]:
        seq_length = x.shape[seq_len_dim]
        if seq_length != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_length
            t = torch.arange(x.shape[seq_len_dim], device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]
        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb(self, x: Tensor) -> Tensor:
        cos = self._cos_cached[:, :, : x.shape[-2], :]
        sin = self._sin_cached[:, :, : x.shape[-2], :]
        return (x * cos) + (self.rotate_half(x) * sin)

    @staticmethod
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

SinusoidalEmbedding

Bases: Embedding

Sinusoidal positional embeddings for inputs with any length.

Freezing

The embeddings are frozen and cannot be trained. They will not be saved in the model’s state_dict.

Padding Idx

Padding symbols are ignored if the padding_idx is specified.

Sequence Length

These embeddings are automatically extended in forward if more positions are needed.

Parameters:

Name            Type        Description                        Default
num_embeddings  int         The number of embeddings to use.   required
embedding_dim   int         The dimension of the embeddings.   required
padding_idx     int | None  The index of the padding symbol.   None
bias            int         The bias of the embeddings.        1
Example

Python
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
>>> input_ids = torch.arange(28).repeat(4).view(4, -1)
>>> input_embeds = torch.randn(4, 28, 64)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # no batch dimension if padding_idx is None
torch.Size([28, 64])
>>> input_embeds = input_embeds + embeddings
>>> input_embeds.shape
torch.Size([4, 28, 64])
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64, padding_idx=0)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # batch dimension if padding_idx is not None
torch.Size([4, 28, 64])
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
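The automatic extension mentioned above can be exercised directly. In the sketch below (the import path is an assumption), the table is built for 128 positions, yet a 200-token batch works because forward regrows the sinusoidal table on the fly.

Python
import torch
from multimolecule.modules.embeddings import SinusoidalEmbedding  # import path assumed

embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
input_ids = torch.randint(1, 10, (2, 200))  # longer than the 128 pre-built positions
embeddings = embedding(input_ids)
print(embeddings.shape)           # torch.Size([200, 64]); no batch dimension without padding_idx
print(embedding.weight.shape[0])  # 202: the buffer grew to seq_length + bias + 1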

Source code in multimolecule/modules/embeddings/sinusoidal.py
Python
@POSITION_EMBEDDINGS.register("sinusoidal")
@POSITION_EMBEDDINGS_HF.register("sinusoidal")
class SinusoidalEmbedding(nn.Embedding):
    r"""
    Sinusoidal positional embeddings for inputs with any length.

    Note: **Freezing**
        The embeddings are frozen and cannot be trained.
        They will not be saved in the model's state_dict.

    Tip: **Padding Idx**
        Padding symbols are ignored if the padding_idx is specified.

    Success: **Sequence Length**
        These embeddings get automatically extended in forward if more positions is needed.

    Args:
        num_embeddings: The number of embeddings to use.
        embedding_dim: The dimension of the embeddings.
        padding_idx: The index of the padding symbol.
        bias: The bias of the embeddings.

    Example:
        >>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
        >>> input_ids = torch.arange(28).repeat(4).view(4, -1)
        >>> input_embeds = torch.randn(4, 28, 64)
        >>> embeddings = embedding(input_ids)
        >>> embeddings.shape  # no batch dimension if padding_idx is None
        torch.Size([28, 64])
        >>> input_embeds = input_embeds + embeddings
        >>> input_embeds.shape
        torch.Size([4, 28, 64])
        >>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64, padding_idx=0)
        >>> embeddings = embedding(input_ids)
        >>> embeddings.shape  # batch dimension if padding_idx is not None
        torch.Size([4, 28, 64])
        >>> embedding.state_dict()  # no weight in state_dict
        OrderedDict()
    """

    _is_hf_initialized = True

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int | None = None,
        bias: int = 1,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ):
        weight = self.get_embedding(num_embeddings, embedding_dim, padding_idx, device=device, dtype=dtype)
        super().__init__(num_embeddings, embedding_dim, padding_idx, _weight=weight.detach(), _freeze=True, **kwargs)
        del self.weight
        self.register_buffer("weight", weight, persistent=False)
        self.bias = bias

    def update_weight(self, num_embeddings: int):
        weight = self.get_embedding(
            num_embeddings, self.embedding_dim, self.padding_idx, dtype=self.weight.dtype, device=self.weight.device
        )
        self.register_buffer("weight", weight, persistent=False)

    @staticmethod
    def get_embedding(
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> Tensor:
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        if device is None:
            device = torch.get_default_device()
        half_dim = embedding_dim // 2
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        emb = emb.to(device=device, dtype=dtype)
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1, dtype=dtype, device=device)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb.detach()

    @staticmethod
    def get_position_ids(tensor: Tensor, padding_idx: int | None = None):
        """
        Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        if padding_idx is None:
            return torch.cumsum(tensor.new_ones(tensor.size(1), dtype=torch.long), dim=0) - 1
        mask = tensor.ne(padding_idx).long()
        return torch.cumsum(mask, dim=1, dtype=torch.long) * mask + padding_idx

    def forward(self, input_ids: Tensor) -> Tensor:
        _, seq_length = input_ids.shape[:2]
        # expand embeddings if needed
        max_position = seq_length + self.bias + 1
        if self.padding_idx is not None:
            max_position += self.padding_idx
        if max_position > self.weight.size(0):
            self.update_weight(max_position)
        # Need to shift the position ids by the padding index
        position_ids = self.get_position_ids(input_ids, self.padding_idx) + self.bias
        return super().forward(position_ids)

get_embedding staticmethod

Python
get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, device: device | None = None, dtype: dtype = float32) -> Tensor

Build sinusoidal embeddings.

This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
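Since get_embedding is a staticmethod, it can be called without instantiating the module. A small sketch (the import path is an assumption):

Python
import torch
from multimolecule.modules.embeddings import SinusoidalEmbedding  # import path assumed

table = SinusoidalEmbedding.get_embedding(
    num_embeddings=10, embedding_dim=8, padding_idx=0, device=torch.device("cpu")
)
print(table.shape)           # torch.Size([10, 8]): sin features in the first half, cos in the second
print(table[0].abs().sum())  # tensor(0.) -- the padding row is zeroed out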

Source code in multimolecule/modules/embeddings/sinusoidal.py
Python
@staticmethod
def get_embedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """
    Build sinusoidal embeddings.

    This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
    "Attention Is All You Need".
    """
    if device is None:
        device = torch.get_default_device()
    half_dim = embedding_dim // 2
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
    emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
    emb = emb.to(device=device, dtype=dtype)
    if embedding_dim % 2 == 1:
        emb = torch.cat([emb, torch.zeros(num_embeddings, 1, dtype=dtype, device=device)], dim=1)
    if padding_idx is not None:
        emb[padding_idx, :] = 0
    return emb.detach()

get_position_ids staticmethod

Python
get_position_ids(tensor: Tensor, padding_idx: int | None = None)

Replace non-padding symbols with their position numbers.

Position numbers begin at padding_idx+1. Padding symbols are ignored.
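A small worked example (the import path is an assumption; the token values are illustrative): with padding_idx=0, non-padding tokens are numbered from padding_idx + 1 and padding positions keep the value of padding_idx, while passing padding_idx=None yields a single row of 0-based positions shared by the whole batch. forward then adds bias on top of these ids.

Python
import torch
from multimolecule.modules.embeddings import SinusoidalEmbedding  # import path assumed

input_ids = torch.tensor([[5, 7, 9, 0, 0]])  # two trailing padding tokens
print(SinusoidalEmbedding.get_position_ids(input_ids, padding_idx=0))
# tensor([[1, 2, 3, 0, 0]])
print(SinusoidalEmbedding.get_position_ids(input_ids, padding_idx=None))
# tensor([0, 1, 2, 3, 4])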

Source code in multimolecule/modules/embeddings/sinusoidal.py
Python
@staticmethod
def get_position_ids(tensor: Tensor, padding_idx: int | None = None):
    """
    Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    if padding_idx is None:
        return torch.cumsum(tensor.new_ones(tensor.size(1), dtype=torch.long), dim=0) - 1
    mask = tensor.ne(padding_idx).long()
    return torch.cumsum(mask, dim=1, dtype=torch.long) * mask + padding_idx