embeddings

embeddings provides a collection of pre-defined positional embeddings.

multimolecule.modules.embeddings
RotaryEmbedding
Bases: Module
Rotary position embeddings based on those in RoFormer.
Queries and keys are transformed by rotation matrices that depend on their relative positions.
Cache
The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.
Sequence Length
Rotary Embedding is independent of the sequence length and can be used with sequences of any length.
Example
>>> embedding = RotaryEmbedding(embedding_dim=64)
>>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
>>> query, key = embedding(query, key)
>>> query.shape
torch.Size([2, 4, 28, 64])
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
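For intuition, the rotation can be sketched roughly as follows. This is an illustrative stand-alone function, not the module's implementation: the names are hypothetical, and RotaryEmbedding itself caches the cos/sin tables (as noted above) and rotates query and key in a single call.

import torch

def apply_rotary(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, heads, seq_len, dim) with an even dim.
    dim, seq_len = x.size(-1), x.size(-2)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)                        # (seq_len, dim // 2)
    cos, sin = freqs.cos().repeat(1, 2), freqs.sin().repeat(1, 2)   # (seq_len, dim)
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin             # "rotate half" trick

Applying the same rotation to both query and key makes their dot product depend only on the original vectors and their relative position.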
Source code in multimolecule/modules/embeddings/rotary.py
SinusoidalEmbedding
Bases: Embedding
Sinusoidal positional embeddings for inputs of any length.
Freezing
The embeddings are frozen and cannot be trained. They will not be saved in the model’s state_dict.
Padding Idx
Padding symbols are ignored if padding_idx is specified.
Sequence Length
These embeddings are automatically extended in forward if more positions are needed.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
num_embeddings | int | The number of embeddings to use. | required |
embedding_dim | int | The dimension of the embeddings. | required |
padding_idx | int or None | The index of the padding symbol. | None |
bias | int | The bias of the embeddings. | 1 |
Example
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64)
>>> input_ids = torch.arange(28).repeat(4).view(4, -1)
>>> input_embeds = torch.randn(4, 28, 64)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # no batch dimension if padding_idx is None
torch.Size([28, 64])
>>> input_embeds = input_embeds + embeddings
>>> input_embeds.shape
torch.Size([4, 28, 64])
>>> embedding = SinusoidalEmbedding(num_embeddings=128, embedding_dim=64, padding_idx=0)
>>> embeddings = embedding(input_ids)
>>> embeddings.shape  # batch dimension if padding_idx is not None
torch.Size([4, 28, 64])
>>> embedding.state_dict()  # no weight in state_dict
OrderedDict()
Source code in multimolecule/modules/embeddings/sinusoidal.py
get_embedding
staticmethod
get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None, device: device | None = None, dtype: dtype = float32) -> Tensor
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of “Attention Is All You Need”.
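A rough sketch of that construction is shown below. The helper name sinusoidal_table is hypothetical, and the real get_embedding additionally accepts device and dtype; in this layout the first half of each vector holds sines and the second half cosines, instead of the interleaved sin/cos described in the paper.

import math
import torch

def sinusoidal_table(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None) -> torch.Tensor:
    # tensor2tensor-style layout: [sin | cos] halves rather than interleaved sin/cos.
    half_dim = embedding_dim // 2
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000.0) / (half_dim - 1)))
    angles = torch.arange(num_embeddings, dtype=torch.float32).unsqueeze(1) * freq.unsqueeze(0)
    table = torch.cat([angles.sin(), angles.cos()], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad the last column when the dimension is odd
        table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
    if padding_idx is not None:
        table[padding_idx] = 0  # the padding position contributes nothing
    return table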
Source code in multimolecule/modules/embeddings/sinusoidal.py
get_position_ids
staticmethod
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
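A minimal sketch of this behaviour, assuming a 2-D input_ids tensor (the function name is illustrative, not the actual signature):

import torch

def make_position_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    # Cumulatively count non-padding tokens so positions start at padding_idx + 1,
    # then multiply by the mask so padding positions fall back to padding_idx.
    mask = input_ids.ne(padding_idx).long()
    return torch.cumsum(mask, dim=-1) * mask + padding_idx

For example, with padding_idx=0, input_ids of [[0, 5, 7, 0]] yields position ids [[0, 1, 2, 0]].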