embeddings

The embeddings module provides a collection of pre-defined positional embeddings.

multimolecule.modules.embeddings

RotaryEmbedding

Bases: Module

Rotary position embeddings based on those in RoFormer.

Queries and keys are transformed by rotation matrices that depend on their relative positions.

Cache

The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

Sequence Length

Rotary Embedding is independent of the sequence length and can be used with sequences of any length. Use the scale parameter to extend the context length beyond training (e.g., scale=2.0 doubles the effective context).

Example

>>> embedding = RotaryEmbedding(embedding_dim=64)
>>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
>>> query, key = embedding(query, key)
>>> query.shape
torch.Size([2, 4, 28, 64])

For extended context length

>>> embedding_extended = RotaryEmbedding(embedding_dim=64, scale=2.0)
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
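
For intuition, the whole transformation can be re-derived in a few lines. The sketch below is not part of the library; it only mirrors the computation in the source shown next (the module itself caches the cos/sin tables instead of recomputing them):

Python
import torch

def rotary_sketch(x, base=10000.0, scale=1.0):
    # x: (batch, heads, seq_length, embedding_dim)
    dim, seq_length = x.shape[-1], x.shape[-2]
    # one inverse frequency per pair of channels
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # angle = position * inv_freq / scale; a larger scale means slower rotation, hence a longer context
    freqs = torch.outer(torch.arange(seq_length, dtype=torch.float32), inv_freq) / scale
    emb = torch.cat((freqs, freqs), dim=-1)        # (seq_length, embedding_dim)
    cos, sin = emb.cos(), emb.sin()
    x1, x2 = x.chunk(2, dim=-1)
    rotated_half = torch.cat((-x2, x1), dim=-1)    # rotate_half
    return x * cos + rotated_half * sin

out = rotary_sketch(torch.randn(2, 4, 28, 64))

Because the tables depend only on the positions and embedding_dim, the module has no learnable parameters, which is why state_dict() is empty in the example above.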

Source code in multimolecule/modules/embeddings/rotary.py
Python
@POSITION_EMBEDDINGS.register("rotary")
@POSITION_EMBEDDINGS_HF.register("rotary")
class RotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    Tip: **Cache**
        The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

    Success: **Sequence Length**
        Rotary Embedding is irrespective of the sequence length and can be used for any sequence length.
        Use the `scale` parameter to extend context length beyond training (e.g., scale=2.0 doubles effective context).

    Example:
        >>> embedding = RotaryEmbedding(embedding_dim=64)
        >>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
        >>> query, key = embedding(query, key)
        >>> query.shape
        torch.Size([2, 4, 28, 64])
        >>> # For extended context length
        >>> embedding_extended = RotaryEmbedding(embedding_dim=64, scale=2.0)
        >>> embedding.state_dict()  # no weight in state_dict
        OrderedDict()
    """

    _seq_len_cached: int | None = None
    _cos_cached: Tensor | None = None
    _sin_cached: Tensor | None = None

    def __init__(
        self,
        embedding_dim: int,
        base: float = 10000.0,
        scale: float = 1.0,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Initialize rotary position embeddings.

        Args:
            embedding_dim: Dimension of the embeddings (must be even)
            base: Base for computing inverse frequencies. Defaults to 10000.0.
            scale: Scaling factor for frequencies. Values > 1.0 extend context length
                   (e.g., scale=2.0 doubles the effective context). Defaults to 1.0.
            dtype: Data type for computations. Defaults to torch.float32.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, embedding_dim, 2, dtype=dtype) / embedding_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scale = scale

    def forward(self, q: Tensor, k: Tensor, offset: int = 0, seq_length: int | None = None) -> Tuple[Tensor, Tensor]:
        """
        Apply rotary position embeddings to query and key tensors.

        Args:
            q: Query tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            k: Key tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            offset: Position offset for the start of the sequence (used with past_key_values).
                    Defaults to 0.
            seq_length: Full sequence length including offset. If None, uses the sequence length
                    from the input tensors. Required when offset > 0.

        Returns:
            Tuple of (rotated_query, rotated_key) tensors with the same shapes as inputs.
        """
        if offset > 0 and seq_length is None:
            raise ValueError("seq_length must be provided when offset > 0")

        if seq_length is None:
            seq_length = k.shape[-2]

        self._update_cos_sin_tables(k, seq_len_dim=-2, seq_length=seq_length)
        return self.apply_rotary_pos_emb(q, offset=offset), self.apply_rotary_pos_emb(k, offset=offset)

    def _update_cos_sin_tables(
        self, x: Tensor, seq_len_dim: int = 2, seq_length: int | None = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Update cached cos/sin tables for rotary embeddings.

        Args:
            x: Input tensor to determine device and dtype
            seq_len_dim: Dimension containing sequence length (default: -2)
            seq_length: Full sequence length to cache. If None, uses x.shape[seq_len_dim]
        """
        if seq_length is None:
            seq_length = x.shape[seq_len_dim]

        if seq_length != self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_length
            inv_freq = self.inv_freq
            if not isinstance(inv_freq, Tensor):
                raise RuntimeError("inv_freq buffer is not a Tensor")
            t = torch.arange(seq_length, device=x.device, dtype=inv_freq.dtype)
            # Apply scaling: divide frequencies by scale to extend context length
            freqs = torch.outer(t, inv_freq) / self.scale
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]
        # At this point, _cos_cached and _sin_cached are guaranteed to be Tensor
        assert self._cos_cached is not None and self._sin_cached is not None
        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb(self, x: Tensor, offset: int = 0) -> Tensor:
        """
        Apply rotary position embeddings to a tensor.

        Args:
            x: Input tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            offset: Position offset for the start of the sequence (used with past_key_values).
                    Defaults to 0.

        Returns:
            Rotated tensor with the same shape as input.
        """
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("Cos/sin tables not initialized. Call forward() or _update_cos_sin_tables() first.")

        cos = self._cos_cached[:, :, offset : offset + x.shape[-2], :]
        sin = self._sin_cached[:, :, offset : offset + x.shape[-2], :]
        return (x * cos) + (self.rotate_half(x) * sin)

    @staticmethod
    def rotate_half(x: Tensor) -> Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

__init__

Python
__init__(
    embedding_dim: int,
    base: float = 10000.0,
    scale: float = 1.0,
    dtype: dtype = float32,
)

Initialize rotary position embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embedding_dim | int | Dimension of the embeddings (must be even) | required |
| base | float | Base for computing inverse frequencies. | 10000.0 |
| scale | float | Scaling factor for frequencies. Values > 1.0 extend the context length (e.g., scale=2.0 doubles the effective context). | 1.0 |
| dtype | dtype | Data type for computations. | float32 |
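
As a small illustration (not from the library's documentation), the inverse frequencies form a short geometric sequence, one entry per pair of channels:

Python
import torch

embedding_dim, base = 8, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, embedding_dim, 2, dtype=torch.float32) / embedding_dim))
# inv_freq is approximately [1.0, 0.1, 0.01, 0.001]: frequencies decay
# geometrically from 1 toward 1 / base across channel pairs.

The buffer is registered with persistent=False, so it never appears in state_dict().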
Source code in multimolecule/modules/embeddings/rotary.py
Python
def __init__(
    self,
    embedding_dim: int,
    base: float = 10000.0,
    scale: float = 1.0,
    dtype: torch.dtype = torch.float32,
):
    """
    Initialize rotary position embeddings.

    Args:
        embedding_dim: Dimension of the embeddings (must be even)
        base: Base for computing inverse frequencies. Defaults to 10000.0.
        scale: Scaling factor for frequencies. Values > 1.0 extend context length
               (e.g., scale=2.0 doubles the effective context). Defaults to 1.0.
        dtype: Data type for computations. Defaults to torch.float32.
    """
    super().__init__()
    inv_freq = 1.0 / (base ** (torch.arange(0, embedding_dim, 2, dtype=dtype) / embedding_dim))
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.scale = scale

forward

Python
forward(
    q: Tensor,
    k: Tensor,
    offset: int = 0,
    seq_length: int | None = None,
) -> Tuple[Tensor, Tensor]

Apply rotary position embeddings to query and key tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query tensor of shape (batch_size, num_heads, seq_length, embedding_dim) | required |
| k | Tensor | Key tensor of shape (batch_size, num_heads, seq_length, embedding_dim) | required |
| offset | int | Position offset for the start of the sequence (used with past_key_values). | 0 |
| seq_length | int \| None | Full sequence length including offset. If None, uses the sequence length from the input tensors. Required when offset > 0. | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor] | Tuple of (rotated_query, rotated_key) tensors with the same shapes as inputs. |
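
A usage sketch for the offset path, not taken from the library's documentation: when keys from earlier positions are cached (past_key_values), only the newest positions need to be rotated, with seq_length covering the cached prefix as well. The shapes below are illustrative assumptions, and the import path is assumed from this page.

Python
import torch
from multimolecule.modules.embeddings import RotaryEmbedding  # assumed import path

embedding = RotaryEmbedding(embedding_dim=64)
# 27 positions already processed and cached elsewhere; rotate only the new one.
new_q = torch.randn(2, 4, 1, 64)
new_k = torch.randn(2, 4, 1, 64)
new_q, new_k = embedding(new_q, new_k, offset=27, seq_length=28)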

Source code in multimolecule/modules/embeddings/rotary.py
Python
def forward(self, q: Tensor, k: Tensor, offset: int = 0, seq_length: int | None = None) -> Tuple[Tensor, Tensor]:
    """
    Apply rotary position embeddings to query and key tensors.

    Args:
        q: Query tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
        k: Key tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
        offset: Position offset for the start of the sequence (used with past_key_values).
                Defaults to 0.
        seq_length: Full sequence length including offset. If None, uses the sequence length
                from the input tensors. Required when offset > 0.

    Returns:
        Tuple of (rotated_query, rotated_key) tensors with the same shapes as inputs.
    """
    if offset > 0 and seq_length is None:
        raise ValueError("seq_length must be provided when offset > 0")

    if seq_length is None:
        seq_length = k.shape[-2]

    self._update_cos_sin_tables(k, seq_len_dim=-2, seq_length=seq_length)
    return self.apply_rotary_pos_emb(q, offset=offset), self.apply_rotary_pos_emb(k, offset=offset)

apply_rotary_pos_emb

Python
apply_rotary_pos_emb(x: Tensor, offset: int = 0) -> Tensor

Apply rotary position embeddings to a tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch_size, num_heads, seq_length, embedding_dim) | required |
| offset | int | Position offset for the start of the sequence (used with past_key_values). | 0 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Rotated tensor with the same shape as input. |
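
Because each channel pair undergoes a pure rotation, the per-position vector norm is unchanged. A hedged check (tolerance chosen loosely for float32; the import path is assumed from this page):

Python
import torch
from multimolecule.modules.embeddings import RotaryEmbedding  # assumed import path

embedding = RotaryEmbedding(embedding_dim=64)
q, k = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
rq, rk = embedding(q, k)
# rotation preserves length along the embedding dimension
assert torch.allclose(q.norm(dim=-1), rq.norm(dim=-1), atol=1e-4)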

Source code in multimolecule/modules/embeddings/rotary.py
Python
def apply_rotary_pos_emb(self, x: Tensor, offset: int = 0) -> Tensor:
    """
    Apply rotary position embeddings to a tensor.

    Args:
        x: Input tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
        offset: Position offset for the start of the sequence (used with past_key_values).
                Defaults to 0.

    Returns:
        Rotated tensor with the same shape as input.
    """
    if self._cos_cached is None or self._sin_cached is None:
        raise RuntimeError("Cos/sin tables not initialized. Call forward() or _update_cos_sin_tables() first.")

    cos = self._cos_cached[:, :, offset : offset + x.shape[-2], :]
    sin = self._sin_cached[:, :, offset : offset + x.shape[-2], :]
    return (x * cos) + (self.rotate_half(x) * sin)

SinusoidalEmbedding

Bases: Embedding

Sinusoidal positional embeddings for inputs of any length.

Freezing

The embeddings are frozen and cannot be trained. They will not be saved in the model’s state_dict.

Padding Idx

Padding symbols are ignored if the padding_idx is specified.

Sequence Length

These embeddings are automatically extended in forward if more positions are needed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_embeddings | int | The number of embeddings to use. | required |
| embedding_dim | int | The dimension of the embeddings. | required |
| padding_idx | int \| None | The index of the padding symbol. | None |
| bias | int | Constant offset added to the computed position indices. | 1 |
Example

>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
>>> input_ids = torch.arange(28).repeat(4).view(4, -1)
>>> input_embeds = torch.randn(4, 28, 64)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # no batch dimension if padding_idx is None
torch.Size([28, 64])
>>> input_embeds = input_embeds + embeddings
>>> input_embeds.shape
torch.Size([4, 28, 64])
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64, padding_idx=0)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # batch dimension if padding_idx is not None
torch.Size([4, 28, 64])
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
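
The automatic extension noted above can be exercised directly. A hedged sketch, with the import path assumed from this page; the resulting shape follows from the forward method shown below:

Python
import torch
from multimolecule.modules.embeddings import SinusoidalEmbedding  # assumed import path

embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
long_ids = torch.zeros(1, 512, dtype=torch.long)   # longer than num_embeddings
out = embedding(long_ids)
# The weight buffer is rebuilt on the fly to cover the extra positions,
# and with padding_idx=None the output has no batch dimension.
assert out.shape == torch.Size([512, 64])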

Source code in multimolecule/modules/embeddings/sinusoidal.py
Python
@POSITION_EMBEDDINGS.register("sinusoidal")
@POSITION_EMBEDDINGS_HF.register("sinusoidal")
class SinusoidalEmbedding(nn.Embedding):
    r"""
    Sinusoidal positional embeddings for inputs with any length.

    Note: **Freezing**
        The embeddings are frozen and cannot be trained.
        They will not be saved in the model's state_dict.

    Tip: **Padding Idx**
        Padding symbols are ignored if the padding_idx is specified.

    Success: **Sequence Length**
        These embeddings get automatically extended in forward if more positions is needed.

    Args:
        num_embeddings: The number of embeddings to use.
        embedding_dim: The dimension of the embeddings.
        padding_idx: The index of the padding symbol.
        bias: The bias of the embeddings.

    Example:
        >>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
        >>> input_ids = torch.arange(28).repeat(4).view(4, -1)
        >>> input_embeds = torch.randn(4, 28, 64)
        >>> embeddings = embedding(input_ids)
        >>> embeddings.shape  # no batch dimension if padding_idx is None
        torch.Size([28, 64])
        >>> input_embeds = input_embeds + embeddings
        >>> input_embeds.shape
        torch.Size([4, 28, 64])
        >>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64, padding_idx=0)
        >>> embeddings = embedding(input_ids)
        >>> embeddings.shape  # batch dimension if padding_idx is not None
        torch.Size([4, 28, 64])
        >>> embedding.state_dict()  # no weight in state_dict
        OrderedDict()
    """

    _is_hf_initialized = True

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int | None = None,
        bias: int = 1,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ):
        weight = self.get_embedding(num_embeddings, embedding_dim, padding_idx, device=device, dtype=dtype)
        super().__init__(num_embeddings, embedding_dim, padding_idx, _weight=weight.detach(), _freeze=True, **kwargs)
        del self.weight
        self.register_buffer("weight", weight, persistent=False)
        self.bias = bias

    def update_weight(self, num_embeddings: int):
        weight = self.get_embedding(
            num_embeddings, self.embedding_dim, self.padding_idx, dtype=self.weight.dtype, device=self.weight.device
        )
        self.register_buffer("weight", weight, persistent=False)

    @staticmethod
    def get_embedding(
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> Tensor:
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        if device is None:
            device = get_default_device()
        half_dim = embedding_dim // 2
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        emb = emb.to(device=device, dtype=dtype)
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1, dtype=dtype, device=device)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb.detach()

    @staticmethod
    def get_position_ids(tensor: Tensor, padding_idx: int | None = None):
        """
        Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        if padding_idx is None:
            return torch.cumsum(tensor.new_ones(tensor.size(1), dtype=torch.long), dim=0) - 1
        mask = tensor.ne(padding_idx).long()
        return torch.cumsum(mask, dim=1, dtype=torch.long) * mask + padding_idx

    def forward(self, input_ids: Tensor) -> Tensor:
        _, seq_length = input_ids.shape[:2]
        # expand embeddings if needed
        max_position = seq_length + self.bias + 1
        if self.padding_idx is not None:
            max_position += self.padding_idx
        if max_position > self.weight.size(0):
            self.update_weight(max_position)
        # Need to shift the position ids by the padding index
        position_ids = self.get_position_ids(input_ids, self.padding_idx) + self.bias
        return super().forward(position_ids)

get_embedding staticmethod

Python
get_embedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int | None = None,
    device: device | None = None,
    dtype: dtype = float32,
) -> Tensor

Build sinusoidal embeddings.

This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
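
A hedged check of the layout this produces (import path assumed from this page): for each position, the first half of the row holds sines and the second half the matching cosines, and the padding row is zeroed.

Python
import torch
from multimolecule.modules.embeddings import SinusoidalEmbedding  # assumed import path

emb = SinusoidalEmbedding.get_embedding(num_embeddings=16, embedding_dim=8, padding_idx=0)
assert emb.shape == (16, 8)
assert torch.all(emb[0] == 0)  # padding_idx row is zeroed
# sin^2 + cos^2 == 1 for each of the four frequency pairs at any position
assert torch.allclose(emb[1, :4] ** 2 + emb[1, 4:] ** 2, torch.ones(4), atol=1e-5)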

Source code in multimolecule/modules/embeddings/sinusoidal.py
Python
@staticmethod
def get_embedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """
    Build sinusoidal embeddings.

    This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
    "Attention Is All You Need".
    """
    if device is None:
        device = get_default_device()
    half_dim = embedding_dim // 2
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
    emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
    emb = emb.to(device=device, dtype=dtype)
    if embedding_dim % 2 == 1:
        emb = torch.cat([emb, torch.zeros(num_embeddings, 1, dtype=dtype, device=device)], dim=1)
    if padding_idx is not None:
        emb[padding_idx, :] = 0
    return emb.detach()

get_position_ids staticmethod

Python
get_position_ids(
    tensor: Tensor, padding_idx: int | None = None
)

Replace non-padding symbols with their position numbers.

Position numbers begin at padding_idx+1. Padding symbols are ignored.
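
A small worked example, traced from the source below (import path assumed from this page):

Python
import torch
from multimolecule.modules.embeddings import SinusoidalEmbedding  # assumed import path

input_ids = torch.tensor([[1, 5, 6, 1]])  # suppose 1 is the padding index
position_ids = SinusoidalEmbedding.get_position_ids(input_ids, padding_idx=1)
# mask = [[0, 1, 1, 0]]; cumsum = [[0, 1, 2, 2]]; * mask = [[0, 1, 2, 0]]; + padding_idx
assert position_ids.tolist() == [[1, 2, 3, 1]]  # non-padding positions start at padding_idx + 1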

Source code in multimolecule/modules/embeddings/sinusoidal.py
Python
@staticmethod
def get_position_ids(tensor: Tensor, padding_idx: int | None = None):
    """
    Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    if padding_idx is None:
        return torch.cumsum(tensor.new_ones(tensor.size(1), dtype=torch.long), dim=0) - 1
    mask = tensor.ne(padding_idx).long()
    return torch.cumsum(mask, dim=1, dtype=torch.long) * mask + padding_idx