ProteinBERT
Pre-trained model on protein sequences and Gene Ontology annotations using a combined language modeling and annotation prediction objective.
Disclaimer
This is an UNOFFICIAL implementation of ProteinBERT: a universal deep-learning model of protein sequence and function by Nadav Brandes, et al.
The OFFICIAL repository of ProteinBERT is at nadavbra/protein_bert.
Tip
The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.
The team releasing ProteinBERT did not write a model card for this model, so this model card was written by the MultiMolecule team.
Model Details
ProteinBERT is a protein language model with coupled local residue representations and a global protein representation.
It is pre-trained on UniRef90 with a sequence language modeling objective and a Gene Ontology annotation recovery objective.
ProteinBERT uses convolutional local branches and global-attention layers instead of quadratic self-attention, so the architecture has no learned positional table and can be evaluated on variable sequence lengths.
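The variable-length property follows from plain convolution arithmetic, which the sketch below works through. The kernel width of 9 and the wide-branch dilation of 5 are the configuration defaults listed in the API reference; the narrow-branch dilation of 1, the stride of 1, and symmetric zero padding are assumptions of this sketch rather than confirmed details of the implementation.

```python
def same_padding(kernel_size: int, dilation: int = 1) -> int:
    """Zeros added per side so a stride-1 convolution keeps the input length (odd kernels)."""
    return dilation * (kernel_size - 1) // 2


def receptive_field(kernel_size: int, dilation: int = 1) -> int:
    """Number of input positions spanned by a single dilated kernel."""
    return dilation * (kernel_size - 1) + 1


def output_length(length: int, kernel_size: int, dilation: int = 1) -> int:
    """Output length of a stride-1 convolution that uses `same_padding`."""
    padding = same_padding(kernel_size, dilation)
    return length + 2 * padding - dilation * (kernel_size - 1)


# Narrow branch (assumed dilation 1) and wide branch (dilation 5), both with kernel width 9.
for dilation in (1, 5):
    print(
        f"dilation={dilation}: padding={same_padding(9, dilation)}, "
        f"span={receptive_field(9, dilation)} residues, "
        f"33 residues in -> {output_length(33, 9, dilation)} out"
    )
```

Under these assumptions each branch returns one vector per input position for any sequence length, so nothing in the layer depends on a fixed maximum length.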
Model Specification
| Num Layers | Hidden Size | Global Hidden Size | Num Heads | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
|---|---|---|---|---|---|---|---|
| 6 | 128 | 512 | 4 | 15.98 | 7.16 | 3.54 | 1024 |
Links
Usage
The model file depends on the multimolecule library. You can install it using pip:
| Bash |
|---|
| pip install multimolecule
|
Direct Use
Masked Language Modeling
You can use this model directly with a pipeline for masked language modeling:
| Python |
|---|
| import multimolecule # you must import multimolecule to register models
from transformers import pipeline
predictor = pipeline("fill-mask", model="multimolecule/proteinbert")
output = predictor("MVLSPADKTNVKAAW<mask>KVGAHAGEYGAEALER")
|
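The `fill-mask` pipeline in `transformers` returns a list of dictionaries for a single masked position, each with (among others) a `score` and a `token_str` key. A small helper for ranking those candidates might look like the sketch below. The helper name `top_residues` is hypothetical, and the sample list holds made-up placeholder values for illustration, not output from this model.

```python
def top_residues(predictions, k=3):
    """Return the k highest-scoring (residue, score) pairs from fill-mask output."""
    ranked = sorted(predictions, key=lambda p: p["score"], reverse=True)
    return [(p["token_str"], round(p["score"], 3)) for p in ranked[:k]]


# Placeholder values shaped like fill-mask output; NOT real model predictions.
example_predictions = [
    {"score": 0.21, "token_str": "A"},
    {"score": 0.62, "token_str": "G"},
    {"score": 0.05, "token_str": "S"},
]

print(top_residues(example_predictions, k=2))
```

With the real pipeline, the same helper would be called as `top_residues(output)` on the list returned by `predictor(...)`.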
Downstream Use
Here is how to use this model to get the features of a given sequence in PyTorch:
| Python |
|---|
| from multimolecule import ProteinTokenizer, ProteinBertModel
tokenizer = ProteinTokenizer.from_pretrained("multimolecule/proteinbert")
model = ProteinBertModel.from_pretrained("multimolecule/proteinbert")
text = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALER"
input = tokenizer(text, return_tensors="pt")
output = model(**input)
|
Sequence Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for sequence classification or regression.
Here is how to use this model as a backbone to fine-tune for a sequence-level task in PyTorch:
| Python |
|---|
| import torch
from multimolecule import ProteinTokenizer, ProteinBertForSequencePrediction
tokenizer = ProteinTokenizer.from_pretrained("multimolecule/proteinbert")
model = ProteinBertForSequencePrediction.from_pretrained("multimolecule/proteinbert")
text = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALER"
input = tokenizer(text, return_tensors="pt")
label = torch.tensor([1])
output = model(**input, labels=label)
|
Token Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for token classification or regression.
Here is how to use this model as a backbone to fine-tune for a residue-level task in PyTorch:
| Python |
|---|
| import torch
from multimolecule import ProteinTokenizer, ProteinBertForTokenPrediction
tokenizer = ProteinTokenizer.from_pretrained("multimolecule/proteinbert")
model = ProteinBertForTokenPrediction.from_pretrained("multimolecule/proteinbert")
text = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALER"
input = tokenizer(text, return_tensors="pt")
label = torch.randint(2, (1, len(text)))
output = model(**input, labels=label)
|
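The label tensor above has one entry per residue (`len(text)`), while the tokenized input is two positions longer. The API reference examples are consistent with this: ten input ids yield token logits of shape `[1, 8, 1]`. The sketch below shows that bookkeeping in plain Python. It assumes the extra positions are the BOS and EOS tokens with the default ids 1 and 2 from `ProteinBertConfig`; the helper names are hypothetical.

```python
BOS_TOKEN_ID = 1  # default bos_token_id in ProteinBertConfig
EOS_TOKEN_ID = 2  # default eos_token_id in ProteinBertConfig


def residue_positions(input_ids):
    """Indices of tokens that correspond to residues, skipping BOS and EOS."""
    return [i for i, token in enumerate(input_ids) if token not in (BOS_TOKEN_ID, EOS_TOKEN_ID)]


def check_label_length(input_ids, labels):
    """Raise if the number of residue-level labels does not match the number of residues."""
    expected = len(residue_positions(input_ids))
    if len(labels) != expected:
        raise ValueError(f"expected {expected} residue labels, got {len(labels)}")


input_ids = [1, 16, 23, 15, 21, 18, 6, 9, 14, 2]  # ids taken from the API reference examples
labels = [0, 1, 1, 0, 0, 1, 0, 1]                 # one illustrative label per residue
check_label_length(input_ids, labels)
print(len(input_ids), "tokens ->", len(residue_positions(input_ids)), "residue positions")
```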
Training Details
Training Data
ProteinBERT is pre-trained on approximately 106 million protein sequences from UniRef90 and Gene Ontology annotations.
Training Procedure
ProteinBERT is trained with a combined objective over masked protein sequence recovery and Gene Ontology annotation prediction.
Please refer to the original paper for details on the training setup.
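As a rough illustration of the combined objective, the sketch below mirrors how the `ProteinBertForPreTraining.forward` source in the API reference adds an annotation loss to the language modeling loss. It is written in plain Python instead of PyTorch, takes the language modeling loss as a precomputed number, and uses mean-reduced binary cross-entropy on logits; the function names are hypothetical.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, using a numerically stable form."""
    total = 0.0
    for x, y in zip(logits, targets):
        # max(x, 0) - x*y + log(1 + exp(-|x|)) equals -[y*log(s) + (1-y)*log(1-s)], s = sigmoid(x)
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def combined_loss(lm_loss=None, annotation_logits=None, annotation_labels=None):
    """Sum the sequence loss and the annotation loss when annotation labels are given."""
    loss = lm_loss
    if annotation_labels is not None:
        annotation_loss = bce_with_logits(annotation_logits, annotation_labels)
        loss = annotation_loss if loss is None else loss + annotation_loss
    return loss


# Illustrative numbers only: a sequence loss of 2.0 and three annotation channels.
print(combined_loss(2.0, [1.5, -0.5, 0.0], [1.0, 0.0, 0.0]))
```

Because the two terms are added without weights in this sketch, either one can also be used alone by omitting the labels for the other.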
Citation
| BibTeX |
|---|
| @article{brandes2022proteinbert,
title = {ProteinBERT: a universal deep-learning model of protein sequence and function},
author = {Brandes, Nadav and Ofer, Dan and Peleg, Yam and Rappoport, Nadav and Linial, Michal},
year = {2022},
journal = {Bioinformatics},
volume = {38},
number = {8},
pages = {2102--2110},
doi = {10.1093/bioinformatics/btac020},
url = {https://doi.org/10.1093/bioinformatics/btac020},
}
|
Note
The artifacts distributed in this repository are part of the MultiMolecule project.
If MultiMolecule supports your research, please cite the MultiMolecule project as follows:
| BibTeX |
|---|
| @software{chen_2024_12638419,
author = {Chen, Zhiyuan and Zhu, Sophia Y.},
title = {MultiMolecule},
doi = {10.5281/zenodo.12638419},
publisher = {Zenodo},
url = {https://doi.org/10.5281/zenodo.12638419},
year = 2024,
month = may,
day = 4
}
|
Please use GitHub issues of MultiMolecule for any questions or comments on the model card.
Please contact the authors of the ProteinBERT paper for questions or comments on the paper/model.
License
This model implementation is licensed under the GNU Affero General Public License.
For additional terms and clarifications, please refer to our License FAQ.
| Text Only |
|---|
| SPDX-License-Identifier: AGPL-3.0-or-later
|
API Reference
ProteinBertConfig
Bases: PreTrainedConfig
This is the configuration class to store the configuration of a
ProteinBertModel. It is used to instantiate a ProteinBERT model
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the official ProteinBERT checkpoint.
Configuration objects inherit from PreTrainedConfig and can be used to
control the model outputs. Read the documentation from PreTrainedConfig
for more information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Vocabulary size of the ProteinBERT model. Defines the number of different tokens that can be represented by the input_ids passed when calling ProteinBertModel. | 37 |
| hidden_size | int | Dimensionality of the local residue representations. | 128 |
| global_hidden_size | int | Dimensionality of the global protein representation. | 512 |
| annotation_size | int | Number of Gene Ontology annotation channels used by the pretraining objective. | 8943 |
| num_hidden_layers | int | Number of ProteinBERT local/global encoder blocks. | 6 |
| num_attention_heads | int | Number of global-attention heads in each encoder block. | 4 |
| attention_key_size | int | Dimensionality of each global-attention query/key head. | 64 |
| conv_kernel_size | int | Width of the local convolution kernels. | 9 |
| wide_conv_dilation_rate | int | Dilation rate of the wide local convolution branch. | 5 |
| hidden_act | str | Non-linear activation function used by dense and convolutional branches. | 'gelu' |
| initializer_range | float | Standard deviation used by common prediction heads. | 0.02 |
| layer_norm_eps | float | Epsilon used by layer normalization layers. | 0.001 |
| head | HeadConfig \| None | The configuration of the downstream prediction head. | None |
| lm_head | MaskedLMHeadConfig \| None | The configuration of the masked language model head. | None |
Examples:
| Python Console Session |
|---|
| >>> from multimolecule import ProteinBertConfig, ProteinBertModel
>>> configuration = ProteinBertConfig()
>>> model = ProteinBertModel(configuration)
>>> configuration = model.config
|
Source code in multimolecule/models/proteinbert/configuration_proteinbert.py
| Python |
|---|
| class ProteinBertConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`ProteinBertModel`][multimolecule.models.ProteinBertModel]. It is used to instantiate a ProteinBERT model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the official ProteinBERT checkpoint.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the ProteinBERT model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`ProteinBertModel`].
        hidden_size:
            Dimensionality of the local residue representations.
        global_hidden_size:
            Dimensionality of the global protein representation.
        annotation_size:
            Number of Gene Ontology annotation channels used by the pretraining objective.
        num_hidden_layers:
            Number of ProteinBERT local/global encoder blocks.
        num_attention_heads:
            Number of global-attention heads in each encoder block.
        attention_key_size:
            Dimensionality of each global-attention query/key head.
        conv_kernel_size:
            Width of the local convolution kernels.
        wide_conv_dilation_rate:
            Dilation rate of the wide local convolution branch.
        hidden_act:
            Non-linear activation function used by dense and convolutional branches.
        initializer_range:
            Standard deviation used by common prediction heads.
        layer_norm_eps:
            Epsilon used by layer normalization layers.
        head:
            The configuration of the downstream prediction head.
        lm_head:
            The configuration of the masked language model head.

    Examples:
        >>> from multimolecule import ProteinBertConfig, ProteinBertModel
        >>> configuration = ProteinBertConfig()
        >>> model = ProteinBertModel(configuration)
        >>> configuration = model.config
    """

    model_type = "proteinbert"

    def __init__(
        self,
        vocab_size: int = 37,
        hidden_size: int = 128,
        global_hidden_size: int = 512,
        annotation_size: int = 8943,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 4,
        attention_key_size: int = 64,
        conv_kernel_size: int = 9,
        wide_conv_dilation_rate: int = 5,
        hidden_act: str = "gelu",
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1.0e-3,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        unk_token_id: int = 3,
        mask_token_id: int = 4,
        null_token_id: int = 5,
        head: HeadConfig | None = None,
        lm_head: MaskedLMHeadConfig | None = None,
        **kwargs,
    ):
        kwargs.setdefault("tie_word_embeddings", False)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            unk_token_id=unk_token_id,
            mask_token_id=mask_token_id,
            null_token_id=null_token_id,
            **kwargs,
        )
        if global_hidden_size % num_attention_heads != 0:
            raise ValueError(
                "global_hidden_size must be divisible by num_attention_heads; got "
                f"{global_hidden_size} and {num_attention_heads}."
            )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.global_hidden_size = global_hidden_size
        self.annotation_size = annotation_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_key_size = attention_key_size
        self.conv_kernel_size = conv_kernel_size
        self.wide_conv_dilation_rate = wide_conv_dilation_rate
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.head = HeadConfig(**head) if head is not None else None
        self.lm_head = (
            MaskedLMHeadConfig(**lm_head)
            if lm_head is not None
            else MaskedLMHeadConfig(transform=None, transform_act=None, bias=True)
        )
|
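The constructor rejects any `global_hidden_size` that is not a multiple of `num_attention_heads`. The standalone sketch below reproduces that check without importing the library. The function name is hypothetical, and the returned quotient is only described as the per-head share of the global width; how the model uses that share internally is an assumption here.

```python
def per_head_width(global_hidden_size: int = 512, num_attention_heads: int = 4) -> int:
    """Mirror the divisibility check in ProteinBertConfig and return the per-head share."""
    if global_hidden_size % num_attention_heads != 0:
        raise ValueError(
            "global_hidden_size must be divisible by num_attention_heads; got "
            f"{global_hidden_size} and {num_attention_heads}."
        )
    return global_hidden_size // num_attention_heads


print(per_head_width())  # defaults: 512 split across 4 heads

try:
    per_head_width(510, 4)  # 510 is not a multiple of 4
except ValueError as error:
    print("rejected:", error)
```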
ProteinBertForMaskedLM
Bases: ProteinBertPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForMaskedLM
>>> config = ProteinBertConfig()
>>> model = ProteinBertForMaskedLM(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids, labels=input_ids)
>>> output["logits"].shape
torch.Size([1, 10, 37])
|
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| class ProteinBertForMaskedLM(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForMaskedLM
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForMaskedLM(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids, labels=input_ids)
        >>> output["logits"].shape
        torch.Size([1, 10, 37])
    """

    _tied_weights_keys = {
        "lm_head.decoder.bias": "lm_head.bias",
    }

    def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
        tied_weights = super().get_expanded_tied_weights_keys(all_submodels=all_submodels)
        if all_submodels:
            return tied_weights
        return tied_weights | self._tied_weights_keys

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        self.lm_head = MaskedLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, embeddings):
        self.lm_head.decoder = embeddings
        if hasattr(self.lm_head, "bias"):
            self.lm_head.bias = embeddings.bias

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | MaskedLMOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.lm_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
ProteinBertForPreTraining
Bases: ProteinBertPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForPreTraining
>>> config = ProteinBertConfig()
>>> model = ProteinBertForPreTraining(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 10, 37])
>>> output["annotation_logits"].shape
torch.Size([1, 8943])
|
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| class ProteinBertForPreTraining(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForPreTraining
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForPreTraining(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 10, 37])
        >>> output["annotation_logits"].shape
        torch.Size([1, 8943])
    """

    _tied_weights_keys = {
        "lm_head.decoder.bias": "lm_head.bias",
    }

    def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
        tied_weights = super().get_expanded_tied_weights_keys(all_submodels=all_submodels)
        if all_submodels:
            return tied_weights
        return tied_weights | self._tied_weights_keys

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        self.lm_head = MaskedLMHead(config)
        self.annotation_head = ProteinBertAnnotationPredictionHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, embeddings):
        self.lm_head.decoder = embeddings
        if hasattr(self.lm_head, "bias"):
            self.lm_head.bias = embeddings.bias

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        annotation_labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | ProteinBertForPreTrainingOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        lm_output = self.lm_head(outputs, labels)
        annotation_logits = self.annotation_head(outputs.pooler_output)

        loss = lm_output.loss
        if annotation_labels is not None:
            annotation_loss = F.binary_cross_entropy_with_logits(annotation_logits, annotation_labels.float())
            loss = annotation_loss if loss is None else loss + annotation_loss

        return ProteinBertForPreTrainingOutput(
            loss=loss,
            logits=lm_output.logits,
            annotation_logits=annotation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
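Since the annotation loss is a binary cross-entropy on logits, each of the annotation channels can be read as an independent yes/no prediction. One way to turn a row of `annotation_logits` into predicted channels is sketched below in plain Python. The 0.5 threshold and the function names are assumptions for illustration, and the logits are made-up numbers, not model output.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predicted_channels(annotation_logits, threshold: float = 0.5):
    """Indices of annotation channels whose sigmoid probability reaches the threshold."""
    return [i for i, x in enumerate(annotation_logits) if sigmoid(x) >= threshold]


# Made-up logits for five channels; a real row would have `annotation_size` entries.
row = [2.0, -1.0, 0.0, -3.5, 0.7]
print(predicted_channels(row))
```

Mapping a channel index back to a Gene Ontology term requires the annotation vocabulary used during pre-training, which this page does not list.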
ProteinBertForPreTrainingOutput
dataclass
Bases: ModelOutput
Output type of ProteinBertForPreTraining.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss | FloatTensor \| None | Masked language modeling plus annotation prediction loss. | None |
| logits | FloatTensor \| None | Prediction scores of the language modeling head. | None |
| annotation_logits | FloatTensor \| None | Prediction scores of the Gene Ontology annotation head. | None |
| hidden_states | tuple[FloatTensor, ...] \| None | Hidden states of the local representation stack. | None |
| attentions | tuple[FloatTensor, ...] \| None | Global-attention probabilities for each layer. | None |
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| @dataclass
class ProteinBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`ProteinBertForPreTraining`].

    Args:
        loss:
            Masked language modeling plus annotation prediction loss.
        logits:
            Prediction scores of the language modeling head.
        annotation_logits:
            Prediction scores of the Gene Ontology annotation head.
        hidden_states:
            Hidden states of the local representation stack.
        attentions:
            Global-attention probabilities for each layer.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    annotation_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
|
ProteinBertForSequencePrediction
Bases: ProteinBertPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForSequencePrediction
>>> config = ProteinBertConfig()
>>> model = ProteinBertForSequencePrediction(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids, labels=torch.tensor([[1]]))
>>> output["logits"].shape
torch.Size([1, 1])
|
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| class ProteinBertForSequencePrediction(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForSequencePrediction
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForSequencePrediction(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids, labels=torch.tensor([[1]]))
        >>> output["logits"].shape
        torch.Size([1, 1])
    """

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        head_config = HeadConfig(config.head or {})
        if head_config.hidden_size is None:
            # ProteinBert exposes two feature streams of different width: the per-token `last_hidden_state`
            # (hidden_size) and the global `pooler_output` (global_hidden_size). Sequence heads read the pooled
            # stream by default, so any output other than `last_hidden_state` resolves to global_hidden_size.
            head_config.hidden_size = (
                config.hidden_size if head_config.output_name == "last_hidden_state" else config.global_hidden_size
            )
        self.sequence_head = SequencePredictionHead(config, head_config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
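The comment in `__init__` describes how the sequence head picks its input width when none is configured. The standalone sketch below restates that rule with the default sizes so it can be checked without loading the library; the function name is hypothetical.

```python
def resolve_head_width(output_name: str, hidden_size: int = 128, global_hidden_size: int = 512) -> int:
    """Width the sequence head reads: local width for `last_hidden_state`, global width otherwise."""
    return hidden_size if output_name == "last_hidden_state" else global_hidden_size


for name in ("pooler_output", "last_hidden_state"):
    print(name, "->", resolve_head_width(name))
```

A head that reads the pooled stream therefore expects 512 features with the default configuration, while one that reads the per-residue stream expects 128.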
ProteinBertForTokenPrediction
Bases: ProteinBertPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForTokenPrediction
>>> config = ProteinBertConfig()
>>> model = ProteinBertForTokenPrediction(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids, labels=torch.randint(2, (1, 8)))
>>> output["logits"].shape
torch.Size([1, 8, 1])
|
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| class ProteinBertForTokenPrediction(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForTokenPrediction
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForTokenPrediction(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids, labels=torch.randint(2, (1, 8)))
        >>> output["logits"].shape
        torch.Size([1, 8, 1])
    """

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.token_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
ProteinBertModel
Bases: ProteinBertPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertModel, ProteinTokenizer
>>> config = ProteinBertConfig()
>>> model = ProteinBertModel(config)
>>> tokenizer = ProteinTokenizer.from_pretrained("multimolecule/protein")
>>> input = tokenizer("MVLSPADKT", return_tensors="pt")
>>> output = model(**input)
>>> output["last_hidden_state"].shape
torch.Size([1, 11, 128])
>>> output["pooler_output"].shape
torch.Size([1, 512])
|
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| class ProteinBertModel(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertModel, ProteinTokenizer
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertModel(config)
        >>> tokenizer = ProteinTokenizer.from_pretrained("multimolecule/protein")
        >>> input = tokenizer("MVLSPADKT", return_tensors="pt")
        >>> output = model(**input)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 11, 128])
        >>> output["pooler_output"].shape
        torch.Size([1, 512])
    """

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        self.gradient_checkpointing = False
        self.embeddings = ProteinBertEmbeddings(config)
        self.encoder = ProteinBertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @can_return_tuple
    @merge_with_config_defaults
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | ProteinBertModelOutput:
        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
        if attention_mask is None:
            if input_ids is not None and self.pad_token_id is not None:
                attention_mask = input_ids.ne(self.pad_token_id)
            else:
                attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=torch.bool)
        hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)

        if annotations is None:
            annotations = hidden_states.new_zeros(hidden_states.shape[0], self.config.annotation_size)
        annotations = annotations.to(device=hidden_states.device, dtype=hidden_states.dtype)
        global_states = self.embeddings.project_annotations(annotations)

        encoder_outputs = self.encoder(
            hidden_states,
            global_states,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", self.config.output_hidden_states),
            output_attentions=kwargs.get("output_attentions", self.config.output_attentions),
        )

        return ProteinBertModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            pooler_output=encoder_outputs.pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
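The forward pass fills in two inputs when they are omitted: an attention mask derived from padding tokens, and an all-zero annotation vector of length `annotation_size`. The plain-Python sketch below imitates both defaults and shows how a multi-hot annotation vector could be built by hand. The helper names and the channel indices are hypothetical, since the mapping from Gene Ontology terms to channels is not given on this page.

```python
PAD_TOKEN_ID = 0        # default pad_token_id in ProteinBertConfig
ANNOTATION_SIZE = 8943  # default annotation_size in ProteinBertConfig


def padding_mask(batch, pad_token_id=PAD_TOKEN_ID):
    """True for real tokens, False for padding, like `input_ids.ne(pad_token_id)`."""
    return [[token != pad_token_id for token in row] for row in batch]


def multi_hot(channel_indices=(), size=ANNOTATION_SIZE):
    """Annotation vector with 1.0 at the given channels; all zeros when none are given."""
    vector = [0.0] * size
    for index in channel_indices:
        if not 0 <= index < size:
            raise IndexError(f"channel {index} is outside [0, {size})")
        vector[index] = 1.0
    return vector


batch = [[1, 16, 23, 2, 0, 0]]     # one sequence padded with two PAD tokens
print(padding_mask(batch))
print(sum(multi_hot()))            # the default: no known annotations
print(sum(multi_hot([3, 42, 8000])))  # hypothetical channel indices
```

In PyTorch, such a vector would be stacked into a float tensor of shape `(batch_size, annotation_size)` and passed as `annotations`.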
ProteinBertModelOutput
dataclass
Bases: ModelOutput
Base class for ProteinBERT backbone outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| last_hidden_state | FloatTensor \| None | Local residue representations of shape (batch_size, sequence_length, hidden_size). | None |
| pooler_output | FloatTensor \| None | Global protein representations of shape (batch_size, global_hidden_size). | None |
| hidden_states | tuple[FloatTensor, ...] \| None | Hidden states of the local representation stack. | None |
| attentions | tuple[FloatTensor, ...] \| None | Global-attention probabilities for each layer. | None |
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| @dataclass
class ProteinBertModelOutput(ModelOutput):
    """
    Base class for ProteinBERT backbone outputs.

    Args:
        last_hidden_state:
            Local residue representations of shape `(batch_size, sequence_length, hidden_size)`.
        pooler_output:
            Global protein representations of shape `(batch_size, global_hidden_size)`.
        hidden_states:
            Hidden states of the local representation stack.
        attentions:
            Global-attention probabilities for each layer.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
|
ProteinBertPreTrainedModel
Bases: PreTrainedModel
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
| Python |
|---|
| class ProteinBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ProteinBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["ProteinBertLayer"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            init.ones_(module.weight)
            init.zeros_(module.bias)
        elif isinstance(module, ProteinBertGlobalAttention):
            init.xavier_uniform_(module.query)
            init.xavier_uniform_(module.key)
            init.xavier_uniform_(module.value)
|