embeddings¶
embeddings provides a collection of pre-defined positional embeddings.
multimolecule.modules.embeddings¶
RotaryEmbedding¶
Bases: Module
Rotary position embeddings based on those in RoFormer.
Query and keys are transformed by rotation matrices which depend on their relative positions.
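The rotation can be illustrated with a minimal sketch (assuming the common interleaved-pair formulation; this is not the library's implementation, which follows RoFormer and may use a different dimension layout): each pair of embedding dimensions is rotated by an angle that grows linearly with the token position, using one inverse frequency per pair derived from the base.

```python
import torch

def rotate_pairs(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate interleaved (even, odd) dimension pairs by position-dependent angles."""
    seq_len, dim = x.shape[-2], x.shape[-1]
    # One inverse frequency per dimension pair, derived from `base`.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # Angle for position p and pair i is p * inv_freq[i].
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```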
Cache
The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.
Sequence Length
Rotary embedding is independent of the sequence length and can be used with sequences of any length.
Use the scale parameter to extend context length beyond training (e.g., scale=2.0 doubles effective context).
Example
```python
>>> embedding = RotaryEmbedding(embedding_dim=64)
>>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
>>> query, key = embedding(query, key)
>>> query.shape
torch.Size([2, 4, 28, 64])
```
For extended context length¶
```python
>>> embedding_extended = RotaryEmbedding(embedding_dim=64, scale=2.0)
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
```
Source code in multimolecule/modules/embeddings/rotary.py
__init__¶
__init__(
embedding_dim: int,
base: float = 10000.0,
scale: float = 1.0,
dtype: dtype = float32,
)
Initialize rotary position embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embedding_dim | int | Dimension of the embeddings (must be even). | required |
| base | float | Base for computing inverse frequencies. Defaults to 10000.0. | 10000.0 |
| scale | float | Scaling factor for frequencies. Values > 1.0 extend context length (e.g., scale=2.0 doubles the effective context). Defaults to 1.0. | 1.0 |
| dtype | dtype | Data type for computations. Defaults to torch.float32. | float32 |
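A short construction sketch using non-default arguments from the table above (the values are purely illustrative):

```python
>>> import torch
>>> embedding = RotaryEmbedding(embedding_dim=64, base=10000.0, scale=2.0, dtype=torch.float32)
```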
Source code in multimolecule/modules/embeddings/rotary.py
forward¶
forward(
q: Tensor,
k: Tensor,
offset: int = 0,
seq_length: int | None = None,
) -> Tuple[Tensor, Tensor]
Apply rotary position embeddings to query and key tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query tensor (see the Example above for a typical shape). | required |
| k | Tensor | Key tensor (see the Example above for a typical shape). | required |
| offset | int | Position offset for the start of the sequence (used with past_key_values). Defaults to 0. | 0 |
| seq_length | int \| None | Full sequence length including offset. If None, uses the sequence length from the input tensors. Required when offset > 0. | None |
Returns:
| Type | Description |
|---|---|
| Tuple[Tensor, Tensor] | Tuple of (rotated_query, rotated_key) tensors with the same shapes as inputs. |
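The offset and seq_length arguments are intended for incremental decoding with cached keys and values. A hedged sketch, assuming offset marks the absolute position of the first token in the current chunk as described above:

```python
>>> embedding = RotaryEmbedding(embedding_dim=64)
>>> # 27 tokens already processed; rotate only the newest query/key position
>>> q_new, k_new = torch.randn(2, 4, 1, 64), torch.randn(2, 4, 1, 64)
>>> q_new, k_new = embedding(q_new, k_new, offset=27, seq_length=28)
>>> q_new.shape
torch.Size([2, 4, 1, 64])
```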
Source code in multimolecule/modules/embeddings/rotary.py
apply_rotary_pos_emb¶
Apply rotary position embeddings to a tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|  | Tensor | Input tensor to rotate. | required |
| offset | int | Position offset for the start of the sequence (used with past_key_values). Defaults to 0. | 0 |
Returns:
| Type | Description |
|---|---|
| Tensor | Rotated tensor with the same shape as input. |
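A hedged usage sketch, assuming the method is called on an initialized RotaryEmbedding with the tensor passed positionally (the exact signature is not shown above):

```python
>>> embedding = RotaryEmbedding(embedding_dim=64)
>>> x = torch.randn(2, 4, 28, 64)
>>> rotated = embedding.apply_rotary_pos_emb(x, offset=0)
>>> rotated.shape
torch.Size([2, 4, 28, 64])
```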
Source code in multimolecule/modules/embeddings/rotary.py
SinusoidalEmbedding¶
Bases: Embedding
Sinusoidal positional embeddings for inputs of any length.
Freezing
The embeddings are frozen and cannot be trained. They will not be saved in the model’s state_dict.
Padding Idx
Padding symbols are ignored if the padding_idx is specified.
Sequence Length
These embeddings are automatically extended in forward if more positions are needed (see the sketch after the Example below).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_embeddings | int | The number of embeddings to use. | required |
| embedding_dim | int | The dimension of the embeddings. | required |
| padding_idx | int \| None | The index of the padding symbol. | None |
| bias | int | The bias of the embeddings. | 1 |
Example
```python
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
>>> input_ids = torch.arange(28).repeat(4).view(4, -1)
>>> input_embeds = torch.randn(4, 28, 64)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # no batch dimension if padding_idx is None
torch.Size([28, 64])
>>> input_embeds = input_embeds + embeddings
>>> input_embeds.shape
torch.Size([4, 28, 64])
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64, padding_idx=0)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # batch dimension if padding_idx is not None
torch.Size([4, 28, 64])
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
```
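As noted under Sequence Length above, positions beyond num_embeddings are handled by extending the table on the fly. A hedged sketch of what that implies for output shapes, assuming the extension behaves as described:

```python
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
>>> long_ids = torch.arange(200).view(1, -1)  # longer than num_embeddings
>>> embedding(long_ids).shape
torch.Size([200, 64])
```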
Source code in multimolecule/modules/embeddings/sinusoidal.py
get_embedding staticmethod¶
get_embedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int | None = None,
device: device | None = None,
dtype: dtype = float32,
) -> Tensor
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
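A minimal sketch of the tensor2tensor-style construction the description refers to (sin for the first half of the dimensions, cos for the second half); the library's actual get_embedding may differ in details such as how odd dimensions or the padding row are handled:

```python
import math
import torch

def sinusoidal_table(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> torch.Tensor:
    half_dim = embedding_dim // 2
    # Geometric progression of frequencies, as in tensor2tensor.
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000.0) / (half_dim - 1)))
    angles = torch.arange(num_embeddings, dtype=torch.float32)[:, None] * freq[None, :]
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)  # zero-pad odd dims
    if padding_idx is not None:
        table[padding_idx].zero_()  # padding position gets an all-zero embedding
    return table
```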
Source code in multimolecule/modules/embeddings/sinusoidal.py
get_position_ids staticmethod¶
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
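A hedged illustration of the described numbering (not necessarily the library's exact implementation), assuming padding_idx=0:

```python
>>> import torch
>>> padding_idx = 0
>>> input_ids = torch.tensor([[5, 7, 9, 0, 0]])  # trailing zeros are padding
>>> mask = input_ids.ne(padding_idx).long()
>>> torch.cumsum(mask, dim=1) * mask + padding_idx  # positions start at padding_idx + 1
tensor([[1, 2, 3, 0, 0]])
```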