DNABERT-S
Pre-trained model on multi-species genome using a contrastive learning objective for species-aware DNA embeddings.
Disclaimer
This is an UNOFFICIAL implementation of "DNABERT-S: pioneering species differentiation with species-aware DNA embeddings" by Zhihan Zhou et al.
The OFFICIAL repository of DNABERT-S is at MAGICS-LAB/DNABERT_S.
Tip
The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.
The team releasing DNABERT-S did not write a model card for this model, so this model card was written by the MultiMolecule team.
Model Details
DNABERT-S is a BERT-style model built upon DNABERT-2 and fine-tuned with contrastive learning for species-aware DNA embeddings. The model was trained using the proposed Curriculum Contrastive Learning (C²LR) strategy with the Manifold Instance Mixup (MI-Mix) training objective.
DNABERT-S shares the same architecture as DNABERT-2: it uses Byte Pair Encoding (BPE) tokenization, Attention with Linear Biases (ALiBi) instead of learned position embeddings, a GELU-gated linear unit (GeGLU) MLP, and FlashAttention for improved efficiency.
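To make the ALiBi description concrete, the following is a minimal sketch of how per-head linear biases can be computed. It follows the slope recipe from the ALiBi paper and assumes the symmetric distance `|i - j|` commonly used in encoder models; the function names `alibi_slopes` and `alibi_bias` are illustrative and are not part of the multimolecule API.

```python
import math


def alibi_slopes(num_heads: int) -> list[float]:
    """Return one slope per attention head, following the ALiBi paper's recipe."""

    def power_of_two_slopes(n: int) -> list[float]:
        # Geometric sequence: 2^(-8/n), 2^(-16/n), ..., 2^(-8).
        start = 2.0 ** (-8.0 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_two_slopes(num_heads)
    # For head counts that are not a power of two (e.g. 12), take the slopes of the
    # closest smaller power of two and fill the rest from the next power of two.
    closest = 2 ** math.floor(math.log2(num_heads))
    extra = power_of_two_slopes(2 * closest)[0::2][: num_heads - closest]
    return power_of_two_slopes(closest) + extra


def alibi_bias(seq_len: int, slope: float) -> list[list[float]]:
    """Additive attention bias for one head: -slope * |i - j|."""
    return [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]


slopes = alibi_slopes(12)  # 12 heads, as in the model specification
bias = alibi_bias(4, slopes[0])  # bias matrix of the first head for 4 tokens
```

Because the bias depends only on the distance between tokens, no learned position-embedding table is needed.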
Model Specification
| Num Layers | Hidden Size | Num Heads | Intermediate Size | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
|---|---|---|---|---|---|---|---|
| 12 | 768 | 12 | 3072 | 117.07 | 125.83 | 62.92 | 512 |
Usage
The model file depends on the multimolecule library. You can install it using pip:
```bash
pip install multimolecule
```
Direct Use
You can use this model directly with a pipeline for feature extraction:
```python
import multimolecule  # you must import multimolecule to register models
from transformers import pipeline

predictor = pipeline("feature-extraction", model="multimolecule/dnaberts")
output = predictor("ATCGATCGATCG")
```
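The feature-extraction pipeline returns per-token features as nested lists. A common way to reduce them to a single embedding per sequence is mean pooling over the token axis; the sketch below shows this in plain Python on a toy input. The `[batch][token][hidden]` layout and the helper name `mean_pool` are assumptions for illustration, so check the shape of `output` before relying on it.

```python
def mean_pool(token_features: list[list[float]]) -> list[float]:
    """Average per-token feature vectors into one fixed-size embedding."""
    num_tokens = len(token_features)
    hidden_size = len(token_features[0])
    return [
        sum(token[d] for token in token_features) / num_tokens
        for d in range(hidden_size)
    ]


# Toy stand-in for pipeline output: 1 sequence, 3 tokens, 2 hidden dimensions.
toy_output = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
embedding = mean_pool(toy_output[0])
```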
Downstream Use
Here is how to use this model to get the features of a given sequence in PyTorch:
```python
from multimolecule import DnaBertSModel
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("multimolecule/dnaberts")
model = DnaBertSModel.from_pretrained("multimolecule/dnaberts")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")
output = model(**input)
```
Sequence Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for sequence classification or regression.
Here is how to use this model as backbone to fine-tune for a sequence-level task in PyTorch:
```python
import torch
from multimolecule import DnaBertSForSequencePrediction
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("multimolecule/dnaberts")
model = DnaBertSForSequencePrediction.from_pretrained("multimolecule/dnaberts")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")
label = torch.tensor([1])
output = model(**input, labels=label)
```
Token Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for token classification or regression.
Here is how to use this model as backbone to fine-tune for a nucleotide-level task in PyTorch:
```python
import torch
from multimolecule import DnaBertSForTokenPrediction
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("multimolecule/dnaberts")
model = DnaBertSForTokenPrediction.from_pretrained("multimolecule/dnaberts")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")
label = torch.randint(2, (len(text), ))
output = model(**input, labels=label)
```
Contact Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for contact classification or regression.
Here is how to use this model as backbone to fine-tune for a contact-level task in PyTorch:
```python
import torch
from multimolecule import DnaBertSForContactPrediction
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("multimolecule/dnaberts")
model = DnaBertSForContactPrediction.from_pretrained("multimolecule/dnaberts")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")
label = torch.randint(2, (len(text), len(text)))
output = model(**input, labels=label)
```
Training Details
DNABERT-S uses a two-phase Curriculum Contrastive Learning (C²LR) strategy. In phase I, the model is trained with Weighted SimCLR for one epoch. In phase II, the model is further trained with Manifold Instance Mixup (MI-Mix) for two epochs. The training starts from the pre-trained DNABERT-2 checkpoint.
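As a rough illustration of the contrastive objective used in phase I, the following sketch computes a plain SimCLR-style (InfoNCE) loss for one anchor in pure Python. It omits the weighting that distinguishes Weighted SimCLR, and the function names are illustrative; the default temperature mirrors the τ = 0.05 reported for pre-training.

```python
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchor, positive, negatives, temperature=0.05) -> float:
    """Negative log softmax probability of the positive among positive + negatives."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, neg) / temperature for neg in negatives]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_denominator = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_denominator - logits[0]


anchor = [1.0, 0.0]
# The positive is close to the anchor: the loss is small.
loss_good = info_nce(anchor, [0.9, 0.1], [[0.0, 1.0], [-1.0, 0.0]])
# The "positive" is orthogonal while a negative is close: the loss is large.
loss_bad = info_nce(anchor, [0.0, 1.0], [[0.9, 0.1], [-1.0, 0.0]])
```

A low temperature such as 0.05 sharpens the softmax, so small differences in cosine similarity translate into large differences in loss.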
Training Data
The DNABERT-S model was trained on pairs of non-overlapping DNA sequences from the same species, sourced from GenBank. The dataset consists of 47,923 pairs from 17,636 viral genomes, 1 million pairs from 5,011 fungi genomes, and 1 million pairs from 6,402 bacteria genomes. From the total of 2,047,923 pairs, 2 million were randomly selected for training and the rest were used as validation data. All DNA sequences are 10,000 bp in length.
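To illustrate what a pair of non-overlapping sequences from the same genome looks like, here is a minimal sampling sketch. The sampling scheme, the function name `sample_pair`, and the toy genome are assumptions made for illustration; the description above does not say how the pairs were drawn.

```python
import random


def sample_pair(genome: str, length: int, rng: random.Random) -> tuple[str, str]:
    """Draw two non-overlapping windows of `length` bases from one genome."""
    if len(genome) < 2 * length:
        raise ValueError("genome is too short for two non-overlapping windows")
    # Pick the first window so that room for a second one remains to its right.
    first_start = rng.randint(0, len(genome) - 2 * length)
    # The second window starts at or after the end of the first.
    second_start = rng.randint(first_start + length, len(genome) - length)
    return (
        genome[first_start : first_start + length],
        genome[second_start : second_start + length],
    )


rng = random.Random(0)
toy_genome = "".join(rng.choice("ACGT") for _ in range(1_000))
left, right = sample_pair(toy_genome, 100, rng)  # the training data uses 10,000 bp
```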
Training Procedure
Pre-training
The model was trained on 8 NVIDIA A100 80GB GPUs.
- Temperature (τ): 0.05
- Hyperparameter (α): 1.0
- Epochs: 1 (phase I, Weighted SimCLR) + 2 (phase II, MI-Mix)
- Optimizer: Adam
- Learning rate: 3e-6
- Batch size: 48
- Checkpointing: Every 10,000 steps, best selected on validation loss
- Training time: ~48 hours
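
Phase II relies on mixing hidden representations of different instances. The sketch below shows generic mixup of two hidden vectors with a coefficient drawn from Beta(α, α); treating the listed hyperparameter α as this Beta parameter is an assumption, and the function illustrates the mixup idea rather than the paper's exact MI-Mix objective.

```python
import random


def mix_hidden(h_a, h_b, alpha=1.0, rng=None):
    """Return a convex combination of two hidden vectors and the mixing coefficient."""
    rng = rng or random.Random()
    lam = rng.betavariate(alpha, alpha)  # alpha = 1.0 gives a uniform draw on [0, 1]
    mixed = [lam * a + (1.0 - lam) * b for a, b in zip(h_a, h_b)]
    return mixed, lam


mixed, lam = mix_hidden([1.0, 0.0], [0.0, 1.0], alpha=1.0, rng=random.Random(0))
```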
Citation
```bibtex
@article{zhou2025dnaberts,
  title={{DNABERT-S}: pioneering species differentiation with species-aware {DNA} embeddings},
  author={Zhou, Zhihan and Wu, Weimin and Ho, Harrison and Wang, Jiayi and Shi, Lizhen and Davuluri, Ramana V and Wang, Zhong and Liu, Han},
  journal={Bioinformatics},
  volume={41},
  pages={i255--i264},
  year={2025},
  doi={10.1093/bioinformatics/btaf188}
}
```
Note
The artifacts distributed in this repository are part of the MultiMolecule project.
If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:
```bibtex
@software{chen_2024_12638419,
  author = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title = {MultiMolecule},
  doi = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url = {https://doi.org/10.5281/zenodo.12638419},
  year = 2024,
  month = may,
  day = 4
}
```
Please use GitHub issues of MultiMolecule for any questions or comments on the model card.
Please contact the authors of the DNABERT-S paper for questions or comments on the paper/model.
License
This model implementation is licensed under the GNU Affero General Public License.
For additional terms and clarifications, please refer to our License FAQ.
```text
SPDX-License-Identifier: AGPL-3.0-or-later
```
multimolecule.models.dnaberts
DnaBertSConfig
Bases: PreTrainedConfig
This is the configuration class to store the configuration of a
DnaBertSModel. It is used to instantiate a DNABERT-S model according
to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
yield a similar configuration to that of the DNABERT-S
zhihan1996/DNABERT-S architecture.
Configuration objects inherit from PreTrainedConfig and can be used to
control the model outputs. Read the documentation from PreTrainedConfig
for more information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Vocabulary size of the DnaBertS model. Defines the number of different tokens that can be represented by the input_ids passed when calling DnaBertSModel. | 4096 |
| hidden_size | int | Dimensionality of the encoder layers and the pooler layer. | 768 |
| num_hidden_layers | int | Number of hidden layers in the Transformer encoder. | 12 |
| num_attention_heads | int | Number of attention heads for each attention layer in the Transformer encoder. | 12 |
| intermediate_size | int | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. | 3072 |
| hidden_act | str | The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "silu" and "gelu_new" are supported. | 'gelu' |
| hidden_dropout | float | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | 0.1 |
| attention_dropout | float | The dropout ratio for the attention probabilities. | 0.0 |
| max_position_embeddings | int | The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). | 512 |
| initializer_range | float | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | 0.02 |
| layer_norm_eps | float | The epsilon used by the layer normalization layers. | 1e-12 |
| position_embedding_type | str | Type of position embedding. DNABERT-S uses "alibi" (Attention with Linear Biases). | 'alibi' |
| alibi_starting_size | int | The starting size for the ALiBi position bias tensor. | 512 |
| is_decoder | bool | Whether the model is used as a decoder or not. If False, the model is used as an encoder. | False |
| use_cache | bool | Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True. | True |
| head | HeadConfig \| None | The configuration of the head. | None |
| lm_head | MaskedLMHeadConfig \| None | The configuration of the masked language model head. | None |
| add_cross_attention | bool | Whether to add cross-attention layers when the model is used as a decoder. | False |
Examples:
```pycon
>>> from multimolecule import DnaBertSConfig, DnaBertSModel
>>> # Initializing a DNABERT-S multimolecule/dnaberts style configuration
>>> configuration = DnaBertSConfig()
>>> # Initializing a model (with random weights) from the multimolecule/dnaberts style configuration
>>> model = DnaBertSModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
Source code in multimolecule/models/dnaberts/configuration_dnaberts.py
```python
class DnaBertSConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`DnaBertSModel`][multimolecule.models.DnaBertSModel]. It is used to instantiate a DNABERT-S model according
    to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the DNABERT-S
    [zhihan1996/DNABERT-S](https://huggingface.co/zhihan1996/DNABERT-S) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the DnaBertS model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`DnaBertSModel`].
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers:
            Number of hidden layers in the Transformer encoder.
        num_attention_heads:
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size:
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act:
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout:
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout:
            The dropout ratio for the attention probabilities.
        max_position_embeddings:
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range:
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        position_embedding_type:
            Type of position embedding. DNABERT-S uses `"alibi"` (Attention with Linear Biases).
        alibi_starting_size:
            The starting size for the ALiBi position bias tensor.
        is_decoder:
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache:
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        head:
            The configuration of the head.
        lm_head:
            The configuration of the masked language model head.
        add_cross_attention:
            Whether to add cross-attention layers when the model is used as a decoder.

    Examples:
        >>> from multimolecule import DnaBertSConfig, DnaBertSModel
        >>> # Initializing a DNABERT-S multimolecule/dnaberts style configuration
        >>> configuration = DnaBertSConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/dnaberts style configuration
        >>> model = DnaBertSModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "dnaberts"

    def __init__(
        self,
        vocab_size: int = 4096,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.0,
        max_position_embeddings: int = 512,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "alibi",
        alibi_starting_size: int = 512,
        is_decoder: bool = False,
        use_cache: bool = True,
        head: HeadConfig | None = None,
        lm_head: MaskedLMHeadConfig | None = None,
        add_cross_attention: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        validate_attention_dimensions(hidden_size, num_attention_heads)
        self.vocab_size = vocab_size
        self.type_vocab_size = 2
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.alibi_starting_size = alibi_starting_size
        self.is_decoder = is_decoder
        self.use_cache = use_cache
        self.head = HeadConfig(**head) if head is not None else None
        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
        self.add_cross_attention = add_cross_attention
```
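The constructor calls `validate_attention_dimensions(hidden_size, num_attention_heads)` before storing the arguments. Its implementation is not shown on this page; the stand-in below is a hypothetical sketch of the usual constraint such a check enforces, namely that the hidden size divides evenly across the attention heads.

```python
def check_attention_dimensions(hidden_size: int, num_attention_heads: int) -> int:
    """Hypothetical stand-in: return the per-head size, or raise if it is not an integer."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by "
            f"num_attention_heads ({num_attention_heads})"
        )
    return hidden_size // num_attention_heads


head_size = check_attention_dimensions(768, 12)  # default configuration values
```

With the defaults, each of the 12 heads works on a 64-dimensional slice of the 768-dimensional hidden state.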
DnaBertSForContactPrediction
Bases: DnaBertSPreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import DnaBertSConfig, DnaBertSForContactPrediction
>>> config = DnaBertSConfig()
>>> model = DnaBertSForContactPrediction(config)
>>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
>>> output = model(input_ids, labels=torch.randint(2, (1, 14, 14)))
>>> output["logits"].shape
torch.Size([1, 14, 14, 1])
>>> output["loss"]
tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
```
Source code in multimolecule/models/dnaberts/modeling_dnaberts.py
```python
class DnaBertSForContactPrediction(DnaBertSPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import DnaBertSConfig, DnaBertSForContactPrediction
        >>> config = DnaBertSConfig()
        >>> model = DnaBertSForContactPrediction(config)
        >>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
        >>> output = model(input_ids, labels=torch.randint(2, (1, 14, 14)))
        >>> output["logits"].shape
        torch.Size([1, 14, 14, 1])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
    """

    def __init__(self, config: DnaBertSConfig):
        super().__init__(config)
        self.model = DnaBertSModel(config, add_pooling_layer=False)
        self.contact_head = ContactPredictionHead(config)
        self.head_config = self.contact_head.config
        self.require_attentions = self.contact_head.require_attentions

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | ContactPredictorOutput:
        if self.require_attentions:
            output_attentions = kwargs.get("output_attentions", self.config.output_attentions)
            if output_attentions is False:
                warn("output_attentions must be True since prediction head requires attentions.")
            kwargs["output_attentions"] = True

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.contact_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return ContactPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
DnaBertSForSequencePrediction
Bases: DnaBertSPreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import DnaBertSConfig, DnaBertSForSequencePrediction
>>> config = DnaBertSConfig()
>>> model = DnaBertSForSequencePrediction(config)
>>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
>>> output = model(input_ids, labels=torch.tensor([[1]]))
>>> output["logits"].shape
torch.Size([1, 1])
>>> output["loss"]
tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
```
Source code in multimolecule/models/dnaberts/modeling_dnaberts.py
```python
class DnaBertSForSequencePrediction(DnaBertSPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import DnaBertSConfig, DnaBertSForSequencePrediction
        >>> config = DnaBertSConfig()
        >>> model = DnaBertSForSequencePrediction(config)
        >>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
        >>> output = model(input_ids, labels=torch.tensor([[1]]))
        >>> output["logits"].shape
        torch.Size([1, 1])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
    """

    def __init__(self, config: DnaBertSConfig):
        super().__init__(config)
        self.model = DnaBertSModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
DnaBertSForTokenPrediction
Bases: DnaBertSPreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import DnaBertSConfig, DnaBertSForTokenPrediction
>>> config = DnaBertSConfig()
>>> model = DnaBertSForTokenPrediction(config)
>>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
>>> output = model(input_ids, labels=torch.randint(2, (1, 14)))
>>> output["logits"].shape
torch.Size([1, 14, 1])
>>> output["loss"]
tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
```
Source code in multimolecule/models/dnaberts/modeling_dnaberts.py
```python
class DnaBertSForTokenPrediction(DnaBertSPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import DnaBertSConfig, DnaBertSForTokenPrediction
        >>> config = DnaBertSConfig()
        >>> model = DnaBertSForTokenPrediction(config)
        >>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
        >>> output = model(input_ids, labels=torch.randint(2, (1, 14)))
        >>> output["logits"].shape
        torch.Size([1, 14, 1])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
    """

    def __init__(self, config: DnaBertSConfig):
        super().__init__(config)
        self.model = DnaBertSModel(config, add_pooling_layer=False)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.token_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
DnaBertSModel
Bases: DnaBertSPreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import DnaBertSConfig, DnaBertSModel
>>> config = DnaBertSConfig()
>>> model = DnaBertSModel(config)
>>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 16, 768])
>>> output["pooler_output"].shape
torch.Size([1, 768])
```
Source code in multimolecule/models/dnaberts/modeling_dnaberts.py
```python
class DnaBertSModel(DnaBertSPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import DnaBertSConfig, DnaBertSModel
        >>> config = DnaBertSConfig()
        >>> model = DnaBertSModel(config)
        >>> input_ids = torch.randint(0, config.vocab_size, (1, 16))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 16, 768])
        >>> output["pooler_output"].shape
        torch.Size([1, 768])
    """

    def __init__(self, config: DnaBertSConfig, add_pooling_layer: bool = True):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        self.gradient_checkpointing = False
        self.embeddings = DnaBertSEmbeddings(config)
        self.encoder = DnaBertSEncoder(config)
        self.pooler = DnaBertSPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        encoder_hidden_states: Tensor | None = None,
        encoder_attention_mask: Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | BaseModelOutputWithPoolingAndCrossAttentions:
        r"""
        Args:
            encoder_hidden_states:
                Shape: `(batch_size, sequence_length, hidden_size)`

                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
                the model is configured as a decoder.
            encoder_attention_mask:
                Shape: `(batch_size, sequence_length)`

                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            past_key_values:
                Tuple of length `config.n_layers` with each tuple having 4 tensors of shape
                `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`

                Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up
                decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            use_cache:
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        if isinstance(input_ids, NestedTensor) and attention_mask is None:
            attention_mask = input_ids.mask
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device = input_ids.device
            seq_length = input_ids.shape[1]
        else:
            device = inputs_embeds.device  # type: ignore[union-attr]
            seq_length = inputs_embeds.shape[1]  # type: ignore[union-attr]

        # past_key_values_length
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)

        if attention_mask is None and input_ids is not None and self.pad_token_id is not None:
            attention_mask = input_ids.ne(self.pad_token_id)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )

    def _create_attention_masks(
        self,
        attention_mask,
        encoder_attention_mask,
        embedding_output,
        encoder_hidden_states,
        cache_position,
        past_key_values,
    ):
        if self.config.is_decoder:
            attention_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=embedding_output,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
            )
        else:
            attention_mask = create_bidirectional_mask(
                config=self.config, inputs_embeds=embedding_output, attention_mask=attention_mask
            )

        if encoder_attention_mask is not None:
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=embedding_output,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        return attention_mask, encoder_attention_mask
```
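The input handling at the top of `forward` can be read as two rules: exactly one of `input_ids` and `inputs_embeds` must be given, and a missing attention mask is derived from the padding token. The pure-Python sketch below mirrors that logic on lists instead of tensors; the helper name `resolve_inputs` and the pad id `0` are illustrative assumptions.

```python
def resolve_inputs(input_ids=None, inputs_embeds=None, attention_mask=None, pad_token_id=0):
    """Validate the inputs and derive a default attention mask, as forward does."""
    # Raise unless exactly one of the two inputs is provided (same XOR test as forward).
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    # Without an explicit mask, attend to every non-padding token.
    if attention_mask is None and input_ids is not None and pad_token_id is not None:
        attention_mask = [[token != pad_token_id for token in row] for row in input_ids]
    return attention_mask


mask = resolve_inputs(input_ids=[[5, 9, 7, 0, 0]])  # last two tokens are padding
```

When only `inputs_embeds` is passed, no padding information is available, so the mask stays `None` unless the caller supplies one.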
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| encoder_hidden_states | Tensor \| None | Shape: (batch_size, sequence_length, hidden_size). Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. | None |
| encoder_attention_mask | Tensor \| None | Shape: (batch_size, sequence_length). Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. | None |
| past_key_values | Cache \| None | Tuple of length config.n_layers with each tuple having 4 tensors of shape (batch_size, num_heads, sequence_length - 1, embed_size_per_head). Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If past_key_values are used, the user can optionally input only the last decoder_input_ids (those that don't have their past key value states given to this model) of shape (batch_size, 1) instead of all decoder_input_ids of shape (batch_size, sequence_length). | None |
| use_cache | bool \| None | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | None |
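The 0/1 mask convention above is typically turned into an additive bias before the softmax, so that masked positions receive zero attention weight. The following is a minimal sketch of that conversion in plain Python; it illustrates the convention only and is not the library's `create_bidirectional_mask`.

```python
import math


def to_additive_mask(mask: list[int]) -> list[float]:
    """Map 1 (not masked) to 0.0 and 0 (masked) to -inf."""
    return [0.0 if keep else -math.inf for keep in mask]


def masked_softmax(scores: list[float], mask: list[int]) -> list[float]:
    """Softmax over scores after adding the mask bias (needs at least one unmasked entry)."""
    biased = [s + b for s, b in zip(scores, to_additive_mask(mask))]
    peak = max(biased)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in biased]
    total = sum(exps)
    return [e / total for e in exps]


# The third position is masked, so it receives no attention weight.
weights = masked_softmax([2.0, 1.0, 3.0], [1, 1, 0])
```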
DnaBertSPreTrainedModel
Bases: PreTrainedModel
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Source code in multimolecule/models/dnaberts/modeling_dnaberts.py
```python
class DnaBertSPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DnaBertSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["DnaBertSLayer", "DnaBertSEmbeddings"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        super()._init_weights(module)
```