
ProteinBERT

Pre-trained model on protein sequences and Gene Ontology annotations using a combined language modeling and annotation prediction objective.

Disclaimer

This is an UNOFFICIAL implementation of ProteinBERT: a universal deep-learning model of protein sequence and function by Nadav Brandes et al.

The OFFICIAL repository of ProteinBERT is at nadavbra/protein_bert.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing ProteinBERT did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

ProteinBERT is a protein language model with coupled local residue representations and a global protein representation. It is pre-trained on UniRef90 with a sequence language modeling objective and a Gene Ontology annotation recovery objective. ProteinBERT uses convolutional local branches and global-attention layers instead of quadratic self-attention, so the architecture has no learned positional table and can be evaluated on variable sequence lengths.
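
For intuition, the sketch below implements one global-attention head in plain Python. The query is projected from the global protein vector, keys and values are projected from the local residue vectors, and padded positions receive zero weight. It is a simplified illustration under stated assumptions, not the multimolecule implementation: it uses a single head, takes the projection matrices as plain arguments, and omits any nonlinearities the real layers may apply. The helper names (matvec, softmax, global_attention_head) are hypothetical.

```python
import math


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def softmax(scores):
    """Numerically stable softmax; scores of -inf get exactly zero weight.

    Assumes at least one score is finite.
    """
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def global_attention_head(global_vec, local_vecs, mask, w_q, w_k, w_v):
    """One global-attention head (hypothetical sketch).

    global_vec: global protein representation (length global_dim)
    local_vecs: one vector per residue (each of length local_dim)
    mask: 1 for real residues, 0 for padding
    w_q: key_size x global_dim, w_k: key_size x local_dim, w_v: value_size x local_dim
    """
    # a single query per head, computed from the global representation
    query = matvec(w_q, global_vec)
    scale = math.sqrt(len(query))
    scores = []
    for vec, keep in zip(local_vecs, mask):
        key = matvec(w_k, vec)
        score = sum(q * k for q, k in zip(query, key)) / scale
        # padded positions are pushed to -inf so the softmax ignores them
        scores.append(score if keep else float("-inf"))
    weights = softmax(scores)
    values = [matvec(w_v, vec) for vec in local_vecs]
    # weighted sum over residues gives one vector summarizing the sequence
    out = [sum(w * val[j] for w, val in zip(weights, values)) for j in range(len(w_v))]
    return out, weights
```

Each head issues a single query rather than one query per residue, so the cost of this step grows linearly with sequence length. In the full model, the outputs of several heads are presumably concatenated into the global update. That would explain why ProteinBertConfig requires global_hidden_size to be divisible by num_attention_heads.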

Model Specification

| Num Layers | Hidden Size | Global Hidden Size | Num Heads | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 6 | 128 | 512 | 4 | 15.98 | 7.16 | 3.54 | 1024 |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Masked Language Modeling

You can use this model directly with a pipeline for masked language modeling:

Python
import multimolecule  # you must import multimolecule to register models
from transformers import pipeline

predictor = pipeline("fill-mask", model="multimolecule/proteinbert")
output = predictor("MVLSPADKTNVKAAW<mask>KVGAHAGEYGAEALER")
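
For a single masked position, the Hugging Face fill-mask pipeline returns a list of candidate dictionaries. Each one carries keys such as "score", "token_str", and "sequence". The small helper below (fill_best, a hypothetical name) picks the highest-scoring candidate and substitutes it into the input string.

```python
def fill_best(masked_sequence, predictions, mask_token="<mask>"):
    """Fill the first mask with the top-scoring fill-mask candidate.

    `predictions` is the list of candidate dicts returned for one masked position;
    each dict is expected to contain at least "score" and "token_str".
    """
    best = max(predictions, key=lambda candidate: candidate["score"])
    filled = masked_sequence.replace(mask_token, best["token_str"], 1)
    return filled, best["score"]
```

With the pipeline above, you could call `fill_best("MVLSPADKTNVKAAW<mask>KVGAHAGEYGAEALER", output)`.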

Downstream Use

Extract Features

Here is how to use this model to get the features of a given sequence in PyTorch:

Python
from multimolecule import ProteinTokenizer, ProteinBertModel


tokenizer = ProteinTokenizer.from_pretrained("multimolecule/proteinbert")
model = ProteinBertModel.from_pretrained("multimolecule/proteinbert")

text = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALER"
input = tokenizer(text, return_tensors="pt")

output = model(**input)
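
output.pooler_output already provides a global protein vector (global_hidden_size wide, 512 by default). output.last_hidden_state holds one 128-dimensional vector per token, including the special tokens added by the tokenizer. If you prefer a residue-averaged embedding, a plain-Python sketch of masked mean pooling follows. It assumes the tokenizer uses the ProteinBertConfig default ids for padding, beginning, and end of sequence (0, 1, 2). mean_pool_residues is a hypothetical helper, not a multimolecule function.

```python
def mean_pool_residues(token_features, input_ids, special_ids=(0, 1, 2)):
    """Average per-token feature vectors over residue positions only.

    special_ids defaults to the pad/bos/eos ids of ProteinBertConfig (0, 1, 2).
    """
    # keep only vectors whose token is a residue, not a special token
    kept = [vec for vec, tok in zip(token_features, input_ids) if tok not in special_ids]
    if not kept:
        raise ValueError("no residue tokens to pool")
    dim = len(kept[0])
    return [sum(vec[j] for vec in kept) / len(kept) for j in range(dim)]
```

With the tensors above, one option is `mean_pool_residues(output.last_hidden_state[0].tolist(), input["input_ids"][0].tolist())`.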

Sequence Classification / Regression

Note

This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for sequence classification or regression.

Here is how to use this model as a backbone to fine-tune for a sequence-level task in PyTorch:

Python
import torch
from multimolecule import ProteinTokenizer, ProteinBertForSequencePrediction


tokenizer = ProteinTokenizer.from_pretrained("multimolecule/proteinbert")
model = ProteinBertForSequencePrediction.from_pretrained("multimolecule/proteinbert")

text = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALER"
input = tokenizer(text, return_tensors="pt")
label = torch.tensor([1])

output = model(**input, labels=label)
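
Fine-tuning on more than one sequence at a time requires padding sequences of different lengths into a rectangular batch. The tokenizer can typically do this itself, for example with `tokenizer(sequences, padding=True, return_tensors="pt")`. The sketch below only illustrates what such a batch looks like. It assumes pad_token_id=0, as in the ProteinBertConfig defaults, and the token ids in any example are illustrative. pad_batch is a hypothetical helper.

```python
def pad_batch(token_id_lists, pad_token_id=0):
    """Right-pad token id lists to a common length and build attention masks.

    pad_token_id=0 matches the ProteinBertConfig default.
    """
    width = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        padding = width - len(ids)
        input_ids.append(list(ids) + [pad_token_id] * padding)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(ids) + [0] * padding)
    return input_ids, attention_mask
```

ProteinBertModel also derives a mask from pad_token_id when no attention_mask is passed. Supplying the mask explicitly is still the clearer choice.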

Token Classification / Regression

Note

This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for token classification or regression.

Here is how to use this model as a backbone to fine-tune for a residue-level task in PyTorch:

Python
import torch
from multimolecule import ProteinTokenizer, ProteinBertForTokenPrediction


tokenizer = ProteinTokenizer.from_pretrained("multimolecule/proteinbert")
model = ProteinBertForTokenPrediction.from_pretrained("multimolecule/proteinbert")

text = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALER"
input = tokenizer(text, return_tensors="pt")
label = torch.randint(2, (1, len(text)))

output = model(**input, labels=label)
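
The random labels above stand in for real annotations. Token-level labels cover residues only. The ProteinBertForTokenPrediction docstring example shows this: 10 input ids, including the begin and end tokens, pair with 8 labels. The sketch below builds binary per-residue labels from a hypothetical list of positive positions (for example, annotated binding sites). It then checks that the label count matches the number of residue tokens. The special-token ids assume the ProteinBertConfig defaults, and both helper names are hypothetical.

```python
def residue_labels(sequence, positive_positions):
    """One 0/1 label per residue; positions are 0-based indices into `sequence`."""
    labels = [0] * len(sequence)
    for pos in positive_positions:
        if not 0 <= pos < len(sequence):
            raise IndexError(f"position {pos} is outside a sequence of length {len(sequence)}")
        labels[pos] = 1
    return labels


def check_alignment(input_ids, labels, special_ids=(0, 1, 2)):
    """True if there is exactly one label per residue token (pad/bos/eos excluded)."""
    residues = sum(1 for tok in input_ids if tok not in special_ids)
    return residues == len(labels)
```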

Training Details

Training Data

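ProteinBERT is pre-trained on approximately 106 million protein sequences from UniRef90, together with their Gene Ontology (GO) annotations. The annotation objective covers 8,943 GO annotation channels (the default annotation_size).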

Training Procedure

ProteinBERT is trained with a combined objective over masked protein sequence recovery and Gene Ontology annotation prediction. Please refer to the original paper for details on the training setup.
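
In this implementation, ProteinBertForPreTraining adds the language modeling loss to a binary cross-entropy loss over the annotation logits, and it skips either term when its labels are missing. The sketch below writes that combination out without a framework. It re-derives only the annotation term, using the standard numerically stable form of binary cross-entropy on logits, and takes the language modeling loss as a given number. It covers a single sequence, and the function names are hypothetical.

```python
import math


def bce_with_logits(logit, target):
    """Stable binary cross-entropy on a raw logit: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def pretraining_loss(lm_loss, annotation_logits, annotation_labels):
    """Language modeling loss plus mean BCE over annotation channels.

    Either term may be None; a missing term is skipped, and None is returned
    when both are missing.
    """
    annotation_loss = None
    if annotation_labels is not None:
        terms = [bce_with_logits(x, y) for x, y in zip(annotation_logits, annotation_labels)]
        annotation_loss = sum(terms) / len(terms)
    if lm_loss is None:
        return annotation_loss
    if annotation_loss is None:
        return lm_loss
    return lm_loss + annotation_loss
```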

Citation

BibTeX
@article{brandes2022proteinbert,
  title   = {ProteinBERT: a universal deep-learning model of protein sequence and function},
  author  = {Brandes, Nadav and Ofer, Dan and Peleg, Yam and Rappoport, Nadav and Linial, Michal},
  year    = {2022},
  journal = {Bioinformatics},
  volume  = {38},
  number  = {8},
  pages   = {2102--2110},
  doi     = {10.1093/bioinformatics/btac020},
  url     = {https://doi.org/10.1093/bioinformatics/btac020},
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If MultiMolecule supports your research, please cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the ProteinBERT paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

ProteinBertConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a ProteinBertModel. It is used to instantiate a ProteinBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the official ProteinBERT checkpoint.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Vocabulary size of the ProteinBERT model. Defines the number of different tokens that can be represented by the input_ids passed when calling [ProteinBertModel]. | 37 |
| hidden_size | int | Dimensionality of the local residue representations. | 128 |
| global_hidden_size | int | Dimensionality of the global protein representation. Must be divisible by num_attention_heads. | 512 |
| annotation_size | int | Number of Gene Ontology annotation channels used by the pretraining objective. | 8943 |
| num_hidden_layers | int | Number of ProteinBERT local/global encoder blocks. | 6 |
| num_attention_heads | int | Number of global-attention heads in each encoder block. | 4 |
| attention_key_size | int | Dimensionality of each global-attention query/key head. | 64 |
| conv_kernel_size | int | Width of the local convolution kernels. | 9 |
| wide_conv_dilation_rate | int | Dilation rate of the wide local convolution branch. | 5 |
| hidden_act | str | Non-linear activation function used by the dense and convolutional branches. | 'gelu' |
| initializer_range | float | Standard deviation of the normal initializer for linear, convolutional, and embedding weights. | 0.02 |
| layer_norm_eps | float | Epsilon used by the layer normalization layers. | 0.001 |
| head | HeadConfig \| None | The configuration of the downstream prediction head. | None |
| lm_head | MaskedLMHeadConfig \| None | The configuration of the masked language model head. If None, a head without a transform and with a bias is used. | None |
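
conv_kernel_size and wide_conv_dilation_rate together determine how many neighbouring residues a local convolution sees. For a 1D convolution with kernel size k and dilation d, one output position covers (k - 1) * d + 1 consecutive inputs. The sketch below computes this span. It assumes the narrow branch is undilated. The stacked figure is an upper bound for the purely local path, assuming one wide convolution per block. It ignores the global-attention pathway, which can carry information across the whole sequence.

```python
def conv_span(kernel_size, dilation=1):
    """Residues covered by one output of a dilated 1D convolution."""
    return (kernel_size - 1) * dilation + 1


def stacked_span(kernel_size, dilation, num_layers):
    """Span after num_layers identical dilated convolutions (local path only)."""
    return (kernel_size - 1) * dilation * num_layers + 1


if __name__ == "__main__":
    print(conv_span(9))           # narrow branch, assumed undilated -> 9
    print(conv_span(9, 5))        # wide branch -> 41
    print(stacked_span(9, 5, 6))  # six blocks -> 241
```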

Examples:

Python Console Session
>>> from multimolecule import ProteinBertConfig, ProteinBertModel
>>> configuration = ProteinBertConfig()
>>> model = ProteinBertModel(configuration)
>>> configuration = model.config
Source code in multimolecule/models/proteinbert/configuration_proteinbert.py
Python
class ProteinBertConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`ProteinBertModel`][multimolecule.models.ProteinBertModel]. It is used to instantiate a ProteinBERT model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the official ProteinBERT checkpoint.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the ProteinBERT model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`ProteinBertModel`].
        hidden_size:
            Dimensionality of the local residue representations.
        global_hidden_size:
            Dimensionality of the global protein representation.
        annotation_size:
            Number of Gene Ontology annotation channels used by the pretraining objective.
        num_hidden_layers:
            Number of ProteinBERT local/global encoder blocks.
        num_attention_heads:
            Number of global-attention heads in each encoder block.
        attention_key_size:
            Dimensionality of each global-attention query/key head.
        conv_kernel_size:
            Width of the local convolution kernels.
        wide_conv_dilation_rate:
            Dilation rate of the wide local convolution branch.
        hidden_act:
            Non-linear activation function used by dense and convolutional branches.
        initializer_range:
            Standard deviation of the normal initializer for linear, convolutional, and embedding weights.
        layer_norm_eps:
            Epsilon used by layer normalization layers.
        head:
            The configuration of the downstream prediction head.
        lm_head:
            The configuration of the masked language model head.

    Examples:
        >>> from multimolecule import ProteinBertConfig, ProteinBertModel
        >>> configuration = ProteinBertConfig()
        >>> model = ProteinBertModel(configuration)
        >>> configuration = model.config
    """

    model_type = "proteinbert"

    def __init__(
        self,
        vocab_size: int = 37,
        hidden_size: int = 128,
        global_hidden_size: int = 512,
        annotation_size: int = 8943,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 4,
        attention_key_size: int = 64,
        conv_kernel_size: int = 9,
        wide_conv_dilation_rate: int = 5,
        hidden_act: str = "gelu",
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1.0e-3,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        unk_token_id: int = 3,
        mask_token_id: int = 4,
        null_token_id: int = 5,
        head: HeadConfig | None = None,
        lm_head: MaskedLMHeadConfig | None = None,
        **kwargs,
    ):
        kwargs.setdefault("tie_word_embeddings", False)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            unk_token_id=unk_token_id,
            mask_token_id=mask_token_id,
            null_token_id=null_token_id,
            **kwargs,
        )
        if global_hidden_size % num_attention_heads != 0:
            raise ValueError(
                "global_hidden_size must be divisible by num_attention_heads; got "
                f"{global_hidden_size} and {num_attention_heads}."
            )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.global_hidden_size = global_hidden_size
        self.annotation_size = annotation_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_key_size = attention_key_size
        self.conv_kernel_size = conv_kernel_size
        self.wide_conv_dilation_rate = wide_conv_dilation_rate
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.head = HeadConfig(**head) if head is not None else None
        self.lm_head = (
            MaskedLMHeadConfig(**lm_head)
            if lm_head is not None
            else MaskedLMHeadConfig(transform=None, transform_act=None, bias=True)
        )
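
The only cross-field check in ProteinBertConfig.__init__ is that global_hidden_size divides evenly by num_attention_heads. With the default 4 heads, ProteinBertConfig(global_hidden_size=510) would therefore raise a ValueError. One plausible reading, which this source does not state, is that each head contributes global_hidden_size // num_attention_heads channels of the global update, while queries and keys use attention_key_size. The plain-Python helper below (a hypothetical name) reproduces the check and derives those widths under that assumption.

```python
def global_attention_dims(global_hidden_size=512, num_attention_heads=4, attention_key_size=64):
    """Per-head widths implied by ProteinBertConfig-like values.

    Splitting the global width evenly across heads is an assumption inferred
    from the divisibility check in ProteinBertConfig.__init__.
    """
    if global_hidden_size % num_attention_heads != 0:
        raise ValueError(
            "global_hidden_size must be divisible by num_attention_heads; got "
            f"{global_hidden_size} and {num_attention_heads}."
        )
    value_size = global_hidden_size // num_attention_heads
    return {"key_size": attention_key_size, "value_size": value_size, "num_heads": num_attention_heads}
```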

ProteinBertForMaskedLM

Bases: ProteinBertPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForMaskedLM
>>> config = ProteinBertConfig()
>>> model = ProteinBertForMaskedLM(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids, labels=input_ids)
>>> output["logits"].shape
torch.Size([1, 10, 37])
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
class ProteinBertForMaskedLM(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForMaskedLM
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForMaskedLM(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids, labels=input_ids)
        >>> output["logits"].shape
        torch.Size([1, 10, 37])
    """

    _tied_weights_keys = {
        "lm_head.decoder.bias": "lm_head.bias",
    }

    def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
        tied_weights = super().get_expanded_tied_weights_keys(all_submodels=all_submodels)
        if all_submodels:
            return tied_weights
        return tied_weights | self._tied_weights_keys

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        self.lm_head = MaskedLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, embeddings):
        self.lm_head.decoder = embeddings
        if hasattr(self.lm_head, "bias"):
            self.lm_head.bias = embeddings.bias

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | MaskedLMOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.lm_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
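
The docstring example passes labels=input_ids, which scores every position. For masked language modeling you would usually corrupt only some residues and score only those. The sketch below does this in plain Python. mask_token_id=4 and the special ids 0, 1, 2 follow the ProteinBertConfig defaults. The 15% rate and the use of -100 as an ignored label are assumptions borrowed from common BERT-style practice. PyTorch's cross-entropy ignores -100 by default, but this source does not show whether MaskedLMHead relies on that. mask_for_mlm is a hypothetical helper.

```python
import random


def mask_for_mlm(input_ids, mask_token_id=4, special_ids=(0, 1, 2), rate=0.15,
                 ignore_index=-100, seed=None):
    """Replace a fraction of residue tokens with the mask token.

    Returns (masked_ids, labels): labels keep the original id at masked
    positions and ignore_index everywhere else.
    """
    rng = random.Random(seed)
    # only residue positions are eligible; pad/bos/eos stay untouched
    candidates = [i for i, tok in enumerate(input_ids) if tok not in special_ids]
    num_masked = max(1, round(len(candidates) * rate)) if candidates else 0
    chosen = set(rng.sample(candidates, num_masked))
    masked_ids = [mask_token_id if i in chosen else tok for i, tok in enumerate(input_ids)]
    labels = [tok if i in chosen else ignore_index for i, tok in enumerate(input_ids)]
    return masked_ids, labels
```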

ProteinBertForPreTraining

Bases: ProteinBertPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForPreTraining
>>> config = ProteinBertConfig()
>>> model = ProteinBertForPreTraining(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 10, 37])
>>> output["annotation_logits"].shape
torch.Size([1, 8943])
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
class ProteinBertForPreTraining(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForPreTraining
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForPreTraining(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 10, 37])
        >>> output["annotation_logits"].shape
        torch.Size([1, 8943])
    """

    _tied_weights_keys = {
        "lm_head.decoder.bias": "lm_head.bias",
    }

    def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
        tied_weights = super().get_expanded_tied_weights_keys(all_submodels=all_submodels)
        if all_submodels:
            return tied_weights
        return tied_weights | self._tied_weights_keys

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        self.lm_head = MaskedLMHead(config)
        self.annotation_head = ProteinBertAnnotationPredictionHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, embeddings):
        self.lm_head.decoder = embeddings
        if hasattr(self.lm_head, "bias"):
            self.lm_head.bias = embeddings.bias

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        annotation_labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | ProteinBertForPreTrainingOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        lm_output = self.lm_head(outputs, labels)
        annotation_logits = self.annotation_head(outputs.pooler_output)

        loss = lm_output.loss
        if annotation_labels is not None:
            annotation_loss = F.binary_cross_entropy_with_logits(annotation_logits, annotation_labels.float())
            loss = annotation_loss if loss is None else loss + annotation_loss

        return ProteinBertForPreTrainingOutput(
            loss=loss,
            logits=lm_output.logits,
            annotation_logits=annotation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
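
Both annotations and annotation_labels are dense per-channel tensors. The annotations input feeds the global branch, and annotation_labels is the target of the binary cross-entropy. When annotations is omitted, the backbone fills it with zeros of shape (batch_size, annotation_size), and annotation_logits has shape (batch_size, annotation_size). A multi-hot vector is therefore a natural encoding for a protein's GO terms. The sketch below builds one in plain Python. This section does not document how GO terms map to channel indices, so the indices in any example are placeholders. multi_hot is a hypothetical helper.

```python
def multi_hot(indices, size=8943):
    """Multi-hot annotation vector; 8943 matches the default annotation_size."""
    vector = [0.0] * size
    for idx in indices:
        if not 0 <= idx < size:
            raise IndexError(f"annotation index {idx} is out of range for size {size}")
        vector[idx] = 1.0
    return vector
```

To pass the result to the model, wrap one or more vectors in a tensor of shape (batch_size, annotation_size).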

ProteinBertForPreTrainingOutput dataclass

Bases: ModelOutput

Output type of [ProteinBertForPreTraining].

Parameters:

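| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss | FloatTensor \| None | Sum of the masked language modeling loss (computed when labels is given) and the binary cross-entropy annotation loss (computed when annotation_labels is given); None if neither is given. | None |
| logits | FloatTensor \| None | Prediction scores of the language modeling head, of shape (batch_size, sequence_length, vocab_size). | None |
| annotation_logits | FloatTensor \| None | Prediction scores of the Gene Ontology annotation head, of shape (batch_size, annotation_size). | None |
| hidden_states | tuple[FloatTensor, ...] \| None | Hidden states of the local representation stack. | None |
| attentions | tuple[FloatTensor, ...] \| None | Global-attention probabilities for each layer. | None |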
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
@dataclass
class ProteinBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`ProteinBertForPreTraining`].

    Args:
        loss:
            Masked language modeling plus annotation prediction loss.
        logits:
            Prediction scores of the language modeling head.
        annotation_logits:
            Prediction scores of the Gene Ontology annotation head.
        hidden_states:
            Hidden states of the local representation stack.
        attentions:
            Global-attention probabilities for each layer.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    annotation_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None

ProteinBertForSequencePrediction

Bases: ProteinBertPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForSequencePrediction
>>> config = ProteinBertConfig()
>>> model = ProteinBertForSequencePrediction(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids, labels=torch.tensor([[1]]))
>>> output["logits"].shape
torch.Size([1, 1])
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
class ProteinBertForSequencePrediction(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForSequencePrediction
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForSequencePrediction(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids, labels=torch.tensor([[1]]))
        >>> output["logits"].shape
        torch.Size([1, 1])
    """

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        head_config = HeadConfig(config.head or {})
        if head_config.hidden_size is None:
            # ProteinBert exposes two feature streams of different width: the per-token `last_hidden_state`
            # (hidden_size) and the global `pooler_output` (global_hidden_size). Sequence heads read the pooled
            # stream by default, so any output other than `last_hidden_state` resolves to global_hidden_size.
            head_config.hidden_size = (
                config.hidden_size if head_config.output_name == "last_hidden_state" else config.global_hidden_size
            )
        self.sequence_head = SequencePredictionHead(config, head_config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
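
The width-resolution rule in ProteinBertForSequencePrediction.__init__ is easy to miss. Unless the head config sets hidden_size explicitly, a head reading last_hidden_state expects 128 features (hidden_size). A head reading any other output, including the default pooled stream, expects 512 (global_hidden_size). The plain-Python restatement below makes the rule testable. It is a hypothetical helper, with defaults taken from ProteinBertConfig.

```python
def resolve_head_hidden_size(output_name, hidden_size=128, global_hidden_size=512, explicit=None):
    """Input width of a sequence head, following ProteinBertForSequencePrediction.__init__.

    An explicit head hidden_size wins; otherwise only "last_hidden_state" uses the
    per-token width and every other output name uses the global width.
    """
    if explicit is not None:
        return explicit
    return hidden_size if output_name == "last_hidden_state" else global_hidden_size
```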

ProteinBertForTokenPrediction

Bases: ProteinBertPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertForTokenPrediction
>>> config = ProteinBertConfig()
>>> model = ProteinBertForTokenPrediction(config)
>>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
>>> output = model(input_ids, labels=torch.randint(2, (1, 8)))
>>> output["logits"].shape
torch.Size([1, 8, 1])
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
class ProteinBertForTokenPrediction(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertForTokenPrediction
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertForTokenPrediction(config)
        >>> input_ids = torch.tensor([[1, 16, 23, 15, 21, 18, 6, 9, 14, 2]])
        >>> output = model(input_ids, labels=torch.randint(2, (1, 8)))
        >>> output["logits"].shape
        torch.Size([1, 8, 1])
    """

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.model = ProteinBertModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            annotations=annotations,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
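
In the docstring example above, ProteinBertForTokenPrediction returns one logit per residue: 10 input ids yield logits of shape (1, 8, 1). Predictions therefore line up with the input string position by position. Treating that single logit as a binary score is an assumption about your fine-tuning task. Under that assumption, the sketch below converts the logits for one sequence into probabilities with a numerically stable sigmoid and lists the residues above a threshold. positive_residues is a hypothetical helper; with PyTorch you might pass `output.logits[0, :, 0].tolist()`.

```python
import math


def stable_sigmoid(x):
    """Logistic function that avoids overflow for large negative inputs."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def positive_residues(sequence, residue_logits, threshold=0.5):
    """Return (position, residue, probability) for residues predicted positive.

    Assumes exactly one logit per residue, i.e. special tokens already removed.
    """
    if len(sequence) != len(residue_logits):
        raise ValueError("expected exactly one logit per residue")
    hits = []
    for pos, (residue, logit) in enumerate(zip(sequence, residue_logits)):
        prob = stable_sigmoid(logit)
        if prob >= threshold:
            hits.append((pos, residue, prob))
    return hits
```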

ProteinBertModel

Bases: ProteinBertPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProteinBertConfig, ProteinBertModel, ProteinTokenizer
>>> config = ProteinBertConfig()
>>> model = ProteinBertModel(config)
>>> tokenizer = ProteinTokenizer.from_pretrained("multimolecule/protein")
>>> input = tokenizer("MVLSPADKT", return_tensors="pt")
>>> output = model(**input)
>>> output["last_hidden_state"].shape
torch.Size([1, 11, 128])
>>> output["pooler_output"].shape
torch.Size([1, 512])
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
class ProteinBertModel(ProteinBertPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProteinBertConfig, ProteinBertModel, ProteinTokenizer
        >>> config = ProteinBertConfig()
        >>> model = ProteinBertModel(config)
        >>> tokenizer = ProteinTokenizer.from_pretrained("multimolecule/protein")
        >>> input = tokenizer("MVLSPADKT", return_tensors="pt")
        >>> output = model(**input)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 11, 128])
        >>> output["pooler_output"].shape
        torch.Size([1, 512])
    """

    def __init__(self, config: ProteinBertConfig):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        self.gradient_checkpointing = False
        self.embeddings = ProteinBertEmbeddings(config)
        self.encoder = ProteinBertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @can_return_tuple
    @merge_with_config_defaults
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        annotations: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | ProteinBertModelOutput:
        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
        if attention_mask is None:
            if input_ids is not None and self.pad_token_id is not None:
                attention_mask = input_ids.ne(self.pad_token_id)
            else:
                attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=torch.bool)
        hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
        if annotations is None:
            annotations = hidden_states.new_zeros(hidden_states.shape[0], self.config.annotation_size)
        annotations = annotations.to(device=hidden_states.device, dtype=hidden_states.dtype)
        global_states = self.embeddings.project_annotations(annotations)

        encoder_outputs = self.encoder(
            hidden_states,
            global_states,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", self.config.output_hidden_states),
            output_attentions=kwargs.get("output_attentions", self.config.output_attentions),
        )

        return ProteinBertModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            pooler_output=encoder_outputs.pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
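
ProteinBertModel.forward fills in two inputs when they are omitted. The attention mask is derived from input_ids != pad_token_id, or set to all ones when only inputs_embeds is given. The annotations input defaults to zeros of width annotation_size, so the global branch starts from the projection of an all-zero annotation vector. Embeddings at padded positions are then multiplied by the mask, which zeroes them. The plain-Python restatement below mirrors that logic for a single sequence. The helper names are hypothetical, and the defaults come from ProteinBertConfig.

```python
def default_mask(input_ids, pad_token_id=0):
    """Attention mask derived from padding, as when no mask is passed."""
    return [tok != pad_token_id for tok in input_ids]


def zero_padded(hidden_states, attention_mask):
    """Zero out embedding vectors at padded positions."""
    return [list(vec) if keep else [0.0] * len(vec) for vec, keep in zip(hidden_states, attention_mask)]


def default_annotations(annotation_size=8943):
    """All-zero annotation input used when none is supplied."""
    return [0.0] * annotation_size
```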

ProteinBertModelOutput dataclass

Bases: ModelOutput

Base class for ProteinBERT backbone outputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| last_hidden_state | FloatTensor \| None | Local residue representations of shape (batch_size, sequence_length, hidden_size). | None |
| pooler_output | FloatTensor \| None | Global protein representations of shape (batch_size, global_hidden_size). | None |
| hidden_states | tuple[FloatTensor, ...] \| None | Hidden states of the local representation stack. | None |
| attentions | tuple[FloatTensor, ...] \| None | Global-attention probabilities for each layer. | None |
Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
@dataclass
class ProteinBertModelOutput(ModelOutput):
    """
    Base class for ProteinBERT backbone outputs.

    Args:
        last_hidden_state:
            Local residue representations of shape `(batch_size, sequence_length, hidden_size)`.
        pooler_output:
            Global protein representations of shape `(batch_size, global_hidden_size)`.
        hidden_states:
            Hidden states of the local representation stack.
        attentions:
            Global-attention probabilities for each layer.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None

ProteinBertPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/proteinbert/modeling_proteinbert.py
Python
class ProteinBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ProteinBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["ProteinBertLayer"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            init.ones_(module.weight)
            init.zeros_(module.bias)
        elif isinstance(module, ProteinBertGlobalAttention):
            init.xavier_uniform_(module.query)
            init.xavier_uniform_(module.key)
            init.xavier_uniform_(module.value)