ProGen2
Pre-trained model on protein sequences using a causal language modeling (CLM) objective.
Disclaimer
This is an UNOFFICIAL implementation of ProGen2: Exploring the Boundaries of Protein Language Models by Erik Nijkamp, Jeffrey A. Ruffolo, et al.
The OFFICIAL repository of ProGen2 is at enijkamp/progen.
Tip
The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.
The team releasing ProGen2 did not write a model card for this model, so this model card has been written by the MultiMolecule team.
Model Details
ProGen2 is a GPT-J-style model pre-trained on a large corpus of protein sequences in a self-supervised fashion. This means that the model was trained on the raw amino acids of protein sequences only, with an automatic process to generate inputs and labels from those sequences. Please refer to the Training Details section for more information on the training process.
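The sketch below shows how inputs and labels can be derived from a single tokenized sequence with no annotation involved. It assumes the token ids shown in the `ProteinTokenizer` examples on this page (`<cls>` = 1, `<eos>` = 2, `M` = 16, `G` = 11, `H` = 12). `make_clm_example` is a hypothetical helper written for illustration and is not a MultiMolecule function.

```python
from typing import List, Tuple


def make_clm_example(token_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Split one tokenized sequence into (inputs, labels) for next-token prediction.

    inputs[t] is the last token the model has seen at step t, and labels[t] is the
    token it should predict next, so labels are the inputs shifted left by one.
    """
    if len(token_ids) < 2:
        raise ValueError("need at least two tokens to form one prediction target")
    return token_ids[:-1], token_ids[1:]


# "MGHG" wrapped as <cls> M G H G <eos> (ids assumed from the tokenizer examples).
token_ids = [1, 16, 11, 12, 11, 2]
inputs, labels = make_clm_example(token_ids)
for t, target in enumerate(labels):
    # At step t the model conditions on inputs[: t + 1] and is scored on `target`.
    print(inputs[: t + 1], "->", target)
```

In Hugging Face Transformers causal LM heads, `labels` is usually passed unshifted (equal to `input_ids`), and the one-position shift is applied inside the loss. You would therefore rarely do this slicing by hand.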
Variants
Model Specification
| Variants | Num Layers | Hidden Size | Num Heads | Intermediate Size | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
|---|---|---|---|---|---|---|---|---|
| ProGen2-xlarge | 32 | 4096 | 16 | 16384 | 6443.66 | 6735.76 | 3367.27 | 1024 |
| ProGen2-large | 32 | 2560 | 32 | 10240 | 2517.34 | 2664.21 | 1331.45 | 1024 |
| ProGen2-bfd90 | 32 | 2560 | 32 | 10240 | 2517.34 | 2664.21 | 1331.45 | 1024 |
| ProGen2-base | 27 | 1536 | 16 | 6144 | 764.81 | 826.85 | 413.12 | 2048 |
| ProGen2-oas | 27 | 1536 | 16 | 6144 | 764.81 | 826.85 | 413.12 | 1024 |
| ProGen2-medium | 27 | 1536 | 16 | 6144 | 764.81 | 826.85 | 413.12 | 1024 |
| ProGen2-small | 12 | 1024 | 16 | 4096 | 151.15 | 167.74 | 83.75 | 1024 |
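As a sanity check on the table, the parameter counts are dominated by the transformer blocks. A GPT-J-style block has roughly 4h² attention weights (query, key, value and output projections) and 8h² MLP weights (h → 4h → h, matching Intermediate Size = 4 × Hidden Size above). That gives about 12·L·h² parameters for L layers.

This rough estimate ignores embeddings, biases and layer norms. They contribute little here because the vocabulary has only a few dozen tokens.

```python
# (num_layers, hidden_size, intermediate_size, reported_params_in_millions), copied from the table above.
VARIANTS = {
    "ProGen2-small": (12, 1024, 4096, 151.15),
    "ProGen2-base": (27, 1536, 6144, 764.81),
    "ProGen2-large": (32, 2560, 10240, 2517.34),
    "ProGen2-xlarge": (32, 4096, 16384, 6443.66),
}


def approx_params_millions(num_layers: int, hidden_size: int) -> float:
    """~12 * L * h^2: 4h^2 attention + 8h^2 MLP weights per layer (biases, norms, embeddings ignored)."""
    return 12 * num_layers * hidden_size**2 / 1e6


for name, (layers, hidden, intermediate, reported) in VARIANTS.items():
    assert intermediate == 4 * hidden, name  # the MLP is 4x wider than the residual stream
    estimate = approx_params_millions(layers, hidden)
    rel_err = abs(estimate - reported) / reported
    print(f"{name:<15} estimate={estimate:8.2f}M  reported={reported:8.2f}M  rel_err={rel_err:.2%}")
```

Each estimate lands within about 0.1% of the reported value. For example, 12 · 12 · 1024² ≈ 151.0M for ProGen2-small, versus 151.15M in the table.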
Usage
The model file depends on the multimolecule library. You can install it using pip:
```bash
pip install multimolecule
```
Direct Use
Text Generation
You can use this model directly with a pipeline for text generation:
```python
import multimolecule  # you must import multimolecule to register models
from transformers import pipeline

generator = pipeline("text-generation", model="multimolecule/progen2-base")
output = generator("MGHGVSRPPVVTLR", max_new_tokens=50)
```
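Under the hood, text generation is an autoregressive loop. At each step the model scores the next token given everything so far, one token is appended, and the loop repeats until `max_new_tokens` tokens have been added or an end-of-sequence token appears.

The plain-Python sketch below shows the greedy variant. A hypothetical `next_token_scores` callable stands in for the model. The pipeline's actual strategy (greedy, sampling, beam search) depends on the generation config and on arguments such as `do_sample`, so treat this as an illustration only.

```python
from typing import Callable, List, Sequence


def greedy_generate(
    prompt_ids: List[int],
    next_token_scores: Callable[[Sequence[int]], Sequence[float]],
    max_new_tokens: int,
    eos_id: int,
) -> List[int]:
    """Append the highest-scoring token one step at a time (greedy decoding)."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        scores = next_token_scores(ids)  # one score per vocabulary entry
        next_id = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        ids.append(next_id)
        if next_id == eos_id:  # stop early once the sequence is ended
            break
    return ids


def toy_scores(ids: Sequence[int]) -> List[float]:
    """Hypothetical stand-in for a model: always prefers (last_id + 1) mod 5."""
    scores = [0.0] * 5
    scores[(ids[-1] + 1) % 5] = 1.0
    return scores


print(greedy_generate([0], toy_scores, max_new_tokens=10, eos_id=4))
```

This loop re-reads the whole prefix at every step. With `use_cache=True` (the `ProGen2Config` default), the model instead keeps past keys and values, so each step only has to process the newly appended token.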
Downstream Use
Here is how to use this model to get the features of a given sequence in PyTorch:
```python
from multimolecule import ProteinTokenizer, ProGen2Model

tokenizer = ProteinTokenizer.from_pretrained("multimolecule/progen2-base")
model = ProGen2Model.from_pretrained("multimolecule/progen2-base")

text = "MGHGVSRPPVVTLRPAVLDDCPVLWR"
inputs = tokenizer(text, return_tensors="pt")

output = model(**inputs)
```
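`output.last_hidden_state` contains one vector per token, with shape `(batch_size, sequence_length, hidden_size)`. To get one vector per sequence, a common choice is to average the token vectors while skipping padding.

The sketch below does this in plain Python for a single sequence. `masked_mean_pool` is a hypothetical helper. It does not describe what `ProGen2Pooler` (returned as `pooler_output`) computes.

```python
from typing import List, Sequence


def masked_mean_pool(hidden_states: Sequence[Sequence[float]], attention_mask: Sequence[int]) -> List[float]:
    """Average the per-token vectors whose attention_mask entry is 1 (padding is skipped)."""
    kept = [vec for vec, keep in zip(hidden_states, attention_mask) if keep]
    if not kept:
        raise ValueError("attention_mask does not select any token")
    # zip(*kept) walks the hidden dimensions column by column
    return [sum(column) / len(kept) for column in zip(*kept)]


# Three tokens with hidden_size = 2; the last position is padding.
hidden_states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
attention_mask = [1, 1, 0]
print(masked_mean_pool(hidden_states, attention_mask))
```

In PyTorch the same idea is usually written as `(h * m).sum(1) / m.sum(1)` with `m = attention_mask.unsqueeze(-1)`.

Because ProGen2 is a causal model, the hidden state of the last non-padding token is the only one that has attended to the whole sequence. That makes it another reasonable per-sequence summary.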
Sequence Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for sequence classification or regression.
Here is how to use this model as backbone to fine-tune for a sequence-level task in PyTorch:
```python
import torch
from multimolecule import ProteinTokenizer, ProGen2ForSequencePrediction

tokenizer = ProteinTokenizer.from_pretrained("multimolecule/progen2-base")
model = ProGen2ForSequencePrediction.from_pretrained("multimolecule/progen2-base")

text = "MGHGVSRPPVVTLRPAVLDDCPVLWR"
inputs = tokenizer(text, return_tensors="pt")
label = torch.tensor([1])

output = model(**inputs, labels=label)
```
Token Classification / Regression
Note
This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for token classification or regression.
Here is how to use this model as backbone to fine-tune for a residue-level task in PyTorch:
```python
import torch
from multimolecule import ProteinTokenizer, ProGen2ForTokenPrediction

tokenizer = ProteinTokenizer.from_pretrained("multimolecule/progen2-base")
model = ProGen2ForTokenPrediction.from_pretrained("multimolecule/progen2-base")

text = "MGHGVSRPPVVTLRPAVLDDCPVLWR"
inputs = tokenizer(text, return_tensors="pt")
label = torch.randint(2, (len(text), ))

output = model(**inputs, labels=label)
```
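Note the shapes in this example. The label has `len(text)` entries, while the tokenized input has `len(text) + 2` tokens, because the tokenizer wraps the sequence in `<cls>` (id 1) and `<eos>` (id 2), as its examples further down show.

The snippet therefore relies on the token prediction head discarding those special positions before computing the loss. `ProGen2ForTokenPrediction.forward` passes `input_ids` and `attention_mask` to the head, which is consistent with that, but the exact logic is an assumption here.

A minimal plain-Python sketch of the alignment, using a hypothetical `strip_special_positions` helper:

```python
from typing import List, Sequence

PAD_ID, CLS_ID, EOS_ID = 0, 1, 2  # ids from the ProteinTokenizer examples (assumed)


def strip_special_positions(per_token: Sequence[float], input_ids: Sequence[int]) -> List[float]:
    """Keep only the entries that belong to residues, dropping <pad>, <cls> and <eos> positions."""
    return [value for value, token in zip(per_token, input_ids) if token not in (PAD_ID, CLS_ID, EOS_ID)]


text = "MANLGCWMLV"
input_ids = [1, 16, 6, 17, 15, 11, 7, 24, 16, 15, 23, 2]  # tokenizer output for `text`
per_token_logits = [float(i) for i in range(len(input_ids))]  # stand-in: one logit per token
residue_logits = strip_special_positions(per_token_logits, input_ids)
print(len(input_ids), len(residue_logits), len(text))  # 12 10 10
```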
Training Details
ProGen2 used Causal Language Modeling (CLM) as the pre-training objective: given a protein sequence, the model is trained to predict the next amino acid token autoregressively.
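To make the objective concrete: the model is scored by the negative log-likelihood of each residue given the residues before it, averaged over positions. The exponential of that average is the perplexity.

The sketch below computes this quantity with a toy add-one-smoothed bigram model. A bigram model conditions only on the previous residue, whereas ProGen2 conditions on the entire prefix. The bigram model, the `^` start marker and the 20-letter alphabet are illustrative assumptions.

```python
import math
from collections import Counter, defaultdict
from typing import DefaultDict, Dict, List

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids
START = "^"  # artificial "previous token" for the first residue


def train_bigram(sequences: List[str]) -> Dict[str, Dict[str, float]]:
    """Estimate p(next | previous) from counts, with add-one smoothing."""
    counts: DefaultDict[str, Counter] = defaultdict(Counter)
    for seq in sequences:
        for prev, nxt in zip(START + seq, seq):
            counts[prev][nxt] += 1
    probs = {}
    for prev in START + ALPHABET:
        total = sum(counts[prev].values()) + len(ALPHABET)
        probs[prev] = {aa: (counts[prev][aa] + 1) / total for aa in ALPHABET}
    return probs


def mean_causal_nll(seq: str, probs: Dict[str, Dict[str, float]]) -> float:
    """Average of -log p(x_t | x_<t); the bigram model only looks at x_{t-1}."""
    nll = [-math.log(probs[prev][nxt]) for prev, nxt in zip(START + seq, seq)]
    return sum(nll) / len(nll)


untrained = train_bigram([])
trained = train_bigram(["MGHGVSRPPVVTLR"] * 10)
for name, model in [("untrained", untrained), ("trained", trained)]:
    print(name, "perplexity:", round(math.exp(mean_causal_nll("MGHGVSRPPVVTLR", model)), 2))
```

An untrained (uniform) model has perplexity 20, one choice per amino acid. Fitting the data lowers it, which is what minimizing the causal language modeling loss does at scale.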
Training Data
The ProGen2 models were pre-trained on protein sequence databases:
- UniRef90: A clustered version of the UniProt database at 90% sequence identity, containing approximately 135 million sequences.
- BFD30: The Big Fantastic Database clustered at 30% sequence identity, approximately one-third the size of UniRef90.
- BFD90: The Big Fantastic Database clustered at 90% sequence identity, approximately twice the size of UniRef90.
- OAS: The Observed Antibody Space database, clustered at 85% sequence identity.
Different model variants were trained on different combinations of these datasets:
- progen2-small, progen2-medium, progen2-base, progen2-large, progen2-xlarge: Trained on UniRef90 and BFD30.
- progen2-bfd90: Trained on UniRef90 and BFD90.
- progen2-oas: Trained on the OAS database.
Training Procedure
ProGen2 used causal language modeling (CLM) as the pre-training objective.
Pre-training
The model was trained on Google TPU-v3 pods using JAX.
- Batch size: 500,000 – 1,000,000
- Steps: 350,000 – 400,000
- Optimizer: Adam(β1=0.9, β2=0.999, ε=1e-8)
- Learning rate: 1e-5 – 6e-4
- Learning rate scheduler: Cosine
- Learning rate warm-up: 3,000 – 10,000 steps
- Weight decay: 0.1
- Maximum Gradient Norm: 0.8 – 1.0
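The warm-up plus cosine schedule can be sketched as follows. The exact formula used for ProGen2 is not given here, so this sketch makes two assumptions: a linear warm-up from 0 to the peak rate, then a half-cosine decay to `min_lr` (0 by default).

The example values (peak 6e-4, 3,000 warm-up steps, 350,000 total steps) are picked from within the ranges listed above. They do not necessarily correspond to any specific variant.

```python
import math


def cosine_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warm-up to `peak_lr`, then cosine decay to `min_lr` at `total_steps`."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# 176,500 is the midpoint of the decay phase: 3,000 + (350,000 - 3,000) / 2.
for step in (0, 1_500, 3_000, 176_500, 350_000):
    print(f"step {step:>7}: lr = {cosine_lr(step, 6e-4, 3_000, 350_000):.2e}")
```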
Citation
```bibtex
@ARTICLE{Nijkamp2023-jz,
  title = "{ProGen2}: Exploring the boundaries of protein language models",
  author = "Nijkamp, Erik and Ruffolo, Jeffrey A and Weinstein, Eli N and
            Naik, Nikhil and Madani, Ali",
  abstract = "Attention-based models trained on protein sequences have
              demonstrated incredible success at classification and generation
              tasks relevant for artificial-intelligence-driven protein
              design. However, we lack a sufficient understanding of how very
              large-scale models and data play a role in effective protein
              model development. We introduce a suite of protein language
              models, named ProGen2, that are scaled up to 6.4B parameters and
              trained on different sequence datasets drawn from over a billion
              proteins from genomic, metagenomic, and immune repertoire
              databases. ProGen2 models show state-of-the-art performance in
              capturing the distribution of observed evolutionary sequences,
              generating novel viable sequences, and predicting protein
              fitness without additional fine-tuning. As large model sizes and
              raw numbers of protein sequences continue to become more widely
              accessible, our results suggest that a growing emphasis needs to
              be placed on the data distribution provided to a protein
              sequence model. Our models and code are open sourced for
              widespread adoption in protein engineering. A record of this
              paper's Transparent Peer Review process is included in the
              supplemental information.",
  journal = "Cell Syst.",
  publisher = "Elsevier BV",
  volume = 14,
  number = 11,
  pages = "968--978.e3",
  month = nov,
  year = 2023,
  keywords = "fitness prediction; language modeling; protein design",
  copyright = "http://www.elsevier.com/open-access/userlicense/1.0/",
  language = "en"
}
```
Note
The artifacts distributed in this repository are part of the MultiMolecule project.
If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:
```bibtex
@software{chen_2024_12638419,
  author = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title = {MultiMolecule},
  doi = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url = {https://doi.org/10.5281/zenodo.12638419},
  year = 2024,
  month = may,
  day = 4
}
```
Please use GitHub issues of MultiMolecule for any questions or comments on the model card.
Please contact the authors of the ProGen2 paper for questions or comments on the paper/model.
License
This model implementation is licensed under the GNU Affero General Public License.
For additional terms and clarifications, please refer to our License FAQ.
```text
SPDX-License-Identifier: AGPL-3.0-or-later
```
multimolecule.models.progen2
ProteinTokenizer
Bases: Tokenizer
Tokenizer for Protein sequences.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `alphabet` | `Alphabet \| str \| List[str] \| None` | Alphabet to use for tokenization. If `None`, the standard protein alphabet is used. If a string, it must be the name of a predefined alphabet: `standard`, `iupac`, or `streamline`. If an `Alphabet` or a list of characters, that specific alphabet is used. | `None` |
| `do_upper_case` | `bool` | Whether to convert input to uppercase. | `True` |
Examples:
```pycon
>>> from multimolecule import ProteinTokenizer
>>> tokenizer = ProteinTokenizer()
>>> tokenizer('ACDEFGHIKLMNPQRSTVWYXZBJUO')["input_ids"]
[1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 2]
>>> tokenizer('<pad><cls><eos><unk><mask><null>|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 32, 33, 34, 35, 36, 2]
>>> tokenizer('manlgcwmlv')["input_ids"]
[1, 16, 6, 17, 15, 11, 7, 24, 16, 15, 23, 2]
```
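The example outputs above are consistent with a very simple scheme:

1. Upper-case the text (as `_tokenize` does when `do_upper_case=True`).
2. Split it into single characters and look each one up in a fixed vocabulary.
3. Add `<cls>` at the start and `<eos>` at the end.

The plain-Python sketch below reproduces those ids. The vocabulary order is inferred from the example outputs, and mapping unknown characters to `<unk>` is an assumption. Unlike the real tokenizer, it does not recognise special-token strings such as `<pad>` in the input. `VOCAB` and `encode` are hypothetical names.

```python
from typing import Dict, List

SPECIAL_TOKENS = ["<pad>", "<cls>", "<eos>", "<unk>", "<mask>", "<null>"]  # ids 0-5
RESIDUES = "ACDEFGHIKLMNPQRSTVWYXZBJUO"  # ids 6-31
EXTRA_SYMBOLS = "|.*-?"  # ids 32-36

VOCAB: Dict[str, int] = {
    token: index for index, token in enumerate(SPECIAL_TOKENS + list(RESIDUES) + list(EXTRA_SYMBOLS))
}


def encode(text: str, do_upper_case: bool = True) -> List[int]:
    """Character-level encoding with <cls> first and <eos> last; unknown characters become <unk>."""
    if do_upper_case:
        text = text.upper()
    unk_id = VOCAB["<unk>"]
    return [VOCAB["<cls>"]] + [VOCAB.get(char, unk_id) for char in text] + [VOCAB["<eos>"]]


print(encode("manlgcwmlv"))  # [1, 16, 6, 17, 15, 11, 7, 24, 16, 15, 23, 2]
```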
Source code in multimolecule/tokenisers/protein/tokenization_protein.py
```python
class ProteinTokenizer(Tokenizer):
    """
    Tokenizer for Protein sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard protein alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import ProteinTokenizer
        >>> tokenizer = ProteinTokenizer()
        >>> tokenizer('ACDEFGHIKLMNPQRSTVWYXZBJUO')["input_ids"]
        [1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 2]
        >>> tokenizer('<pad><cls><eos><unk><mask><null>|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 32, 33, 34, 35, 36, 2]
        >>> tokenizer('manlgcwmlv')["input_ids"]
        [1, 16, 6, 17, 15, 11, 7, 24, 16, 15, 23, 2]
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet)
        super().__init__(
            alphabet=alphabet,
            additional_special_tokens=additional_special_tokens,
            do_upper_case=do_upper_case,
            **kwargs,
        )

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        return list(text)
```
ProGen2Config
Bases: PreTrainedConfig
This is the configuration class to store the configuration of a ProGen2Model.
It is used to instantiate a ProGen2 model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the ProGen2
salesforce/progen2 architecture, which follows the GPT-J style transformer.
Configuration objects inherit from PreTrainedConfig and can be used to
control the model outputs. Read the documentation from PreTrainedConfig
for more information.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `vocab_size` | `int` | Vocabulary size of the ProGen2 model. Defines the number of different tokens that can be represented by the `input_ids` passed when calling `ProGen2Model`. | `35` |
| `hidden_size` | `int` | Dimensionality of the decoder layers and the pooler layer. | `1536` |
| `num_hidden_layers` | `int` | Number of hidden layers in the Transformer decoder. | `27` |
| `num_attention_heads` | `int` | Number of attention heads for each attention layer in the Transformer decoder. | `16` |
| `intermediate_size` | `int \| None` | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder. If `None`, it defaults to `4 * hidden_size`. | `None` |
| `hidden_act` | `str` | The non-linear activation function in the decoder and pooler. If a string, `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported. | `'gelu_new'` |
| `embedding_dropout` | `float` | The dropout probability for the embedding layer. | `0.0` |
| `hidden_dropout` | `float` | The dropout probability for residual connections and fully connected layers in the decoder. | `0.0` |
| `attention_dropout` | `float` | The dropout ratio for the attention probabilities. | `0.0` |
| `max_position_embeddings` | `int` | The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512, 1024 or 2048). | `2048` |
| `initializer_range` | `float` | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | `0.02` |
| `layer_norm_eps` | `float` | The epsilon used by the layer normalization layers. | `1e-05` |
| `rotary_dim` | `int \| None` | Dimensionality of rotary position embeddings. If `None`, rotary embeddings are applied across the full head dimension. Must not exceed `hidden_size // num_attention_heads`. | `48` |
| `scale_attn_weights` | `bool` | Whether to scale attention weights by `sqrt(head_dim)`. | `True` |
| `use_cache` | `bool` | Whether the model should return the last key/value attentions (not used by all models). Only relevant if `config.is_decoder=True`. | `True` |
| `is_decoder` | `bool` | Whether the model is used as a decoder. If `False`, the model is used as an encoder. | `True` |
Examples:
```pycon
>>> from multimolecule import ProGen2Config, ProGen2Model
>>> # Initializing a ProGen2 multimolecule/progen2 style configuration
>>> configuration = ProGen2Config()
>>> # Initializing a model (with random weights) from the multimolecule/progen2 style configuration
>>> model = ProGen2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
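`rotary_dim` controls how many dimensions of each attention head receive rotary position embeddings; the remaining dimensions pass through unchanged. With the defaults, the head dimension is 1536 / 16 = 96, so only the first 48 dimensions are rotated.

Below is a plain-Python sketch for a single head vector. It assumes GPT-J-style interleaved pairing (dimensions 2i and 2i+1 rotate together) and a frequency base of 10000. MultiMolecule's `ProGen2RotaryEmbedding` may lay out the pairs or frequencies differently.

```python
import math
from typing import List, Sequence


def apply_partial_rotary(head_vec: Sequence[float], position: int, rotary_dim: int, base: float = 10000.0) -> List[float]:
    """Rotate pairs (2i, 2i+1) for 2i < rotary_dim by position * base**(-2i / rotary_dim)."""
    if rotary_dim % 2 or rotary_dim > len(head_vec):
        raise ValueError("rotary_dim must be even and <= head_dim")
    out = list(head_vec)  # dimensions >= rotary_dim are copied unchanged
    for i in range(rotary_dim // 2):
        angle = position * base ** (-2 * i / rotary_dim)
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x0, x1 = head_vec[2 * i], head_vec[2 * i + 1]
        out[2 * i] = x0 * cos_a - x1 * sin_a
        out[2 * i + 1] = x0 * sin_a + x1 * cos_a
    return out


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


q = [0.1 * (j + 1) for j in range(8)]  # toy query, head_dim = 8
key = [0.2 - 0.05 * j for j in range(8)]  # toy key
# Query-key scores depend only on the offset between positions (5 - 3 == 2 - 0).
print(dot(apply_partial_rotary(q, 5, 4), apply_partial_rotary(key, 3, 4)))
print(dot(apply_partial_rotary(q, 2, 4), apply_partial_rotary(key, 0, 4)))
```

The two printed scores agree up to floating-point error. That relative-position property is why rotary embeddings suit autoregressive decoding.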
Source code in multimolecule/models/progen2/configuration_progen2.py
```python
class ProGen2Config(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProGen2Model`][multimolecule.models.ProGen2Model].
    It is used to instantiate a ProGen2 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the ProGen2
    [salesforce/progen2](https://github.com/salesforce/progen) architecture, which follows the GPT-J style transformer.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the ProGen2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`ProGen2Model`].
        hidden_size:
            Dimensionality of the decoder layers and the pooler layer.
        num_hidden_layers:
            Number of hidden layers in the Transformer decoder.
        num_attention_heads:
            Number of attention heads for each attention layer in the Transformer decoder.
        intermediate_size:
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
            Defaults to `4 * hidden_size` if `None`.
        hidden_act:
            The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        embedding_dropout:
            The dropout probability for the embedding layer.
        hidden_dropout:
            The dropout probability for residual connections and fully connected layers in the decoder.
        attention_dropout:
            The dropout ratio for the attention probabilities.
        max_position_embeddings:
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range:
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        rotary_dim:
            Dimensionality of rotary position embeddings. If `None`, rotary embeddings are applied across the full
            head dimension. Must not exceed `hidden_size // num_attention_heads`.
        scale_attn_weights:
            Whether to scale attention weights by sqrt(head_dim).
        use_cache:
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        is_decoder:
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.

    Examples:
        >>> from multimolecule import ProGen2Config, ProGen2Model
        >>> # Initializing a ProGen2 multimolecule/progen2 style configuration
        >>> configuration = ProGen2Config()
        >>> # Initializing a model (with random weights) from the multimolecule/progen2 style configuration
        >>> model = ProGen2Model(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "progen2"

    def __init__(
        self,
        vocab_size: int = 35,
        hidden_size: int = 1536,
        num_hidden_layers: int = 27,
        num_attention_heads: int = 16,
        intermediate_size: int | None = None,
        hidden_act: str = "gelu_new",
        embedding_dropout: float = 0.0,
        hidden_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        max_position_embeddings: int = 2048,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-5,
        rotary_dim: int | None = 48,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        is_decoder: bool = True,
        **kwargs,
    ):
        kwargs.setdefault("tie_word_embeddings", False)
        kwargs.setdefault("null_token_id", None)
        super().__init__(**kwargs)
        validate_attention_dimensions(hidden_size, num_attention_heads)
        head_dim = hidden_size // num_attention_heads
        if rotary_dim is not None and rotary_dim > head_dim:
            raise ValueError(
                f"rotary_dim ({rotary_dim}) must be <= head_dim ({head_dim} = hidden_size // num_attention_heads)."
            )
        if intermediate_size is None:
            intermediate_size = 4 * hidden_size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.embedding_dropout = embedding_dropout
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.rotary_dim = rotary_dim
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.is_decoder = is_decoder
```
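The dimension checks in `__init__` above can be reproduced in plain Python. This shows which head size a custom configuration yields and when `rotary_dim` has to be lowered.

`derive_dims` is a hypothetical helper. It assumes `validate_attention_dimensions` rejects a `hidden_size` that is not divisible by `num_attention_heads`; that function's body is not shown on this page.

```python
from typing import Optional, Tuple


def derive_dims(
    hidden_size: int,
    num_attention_heads: int,
    rotary_dim: Optional[int] = 48,
    intermediate_size: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (head_dim, intermediate_size), mirroring the checks in ProGen2Config.__init__."""
    if hidden_size % num_attention_heads:  # assumed behaviour of validate_attention_dimensions
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    head_dim = hidden_size // num_attention_heads
    if rotary_dim is not None and rotary_dim > head_dim:
        raise ValueError(f"rotary_dim ({rotary_dim}) must be <= head_dim ({head_dim})")
    if intermediate_size is None:
        intermediate_size = 4 * hidden_size  # same default as ProGen2Config
    return head_dim, intermediate_size


print(derive_dims(1536, 16))  # defaults: (96, 6144)
print(derive_dims(2560, 32))  # ProGen2-large shape: (80, 10240)
try:
    derive_dims(512, 16)  # head_dim 32 is smaller than the default rotary_dim 48
except ValueError as err:
    print("rejected:", err)
print(derive_dims(512, 16, rotary_dim=32))  # (32, 2048)
```

For small custom configurations, the default `rotary_dim=48` can exceed the head dimension. In that case, pass a smaller `rotary_dim` explicitly.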
ProGen2ForCausalLM
Bases: ProGen2PreTrainedModel, GenerationMixin
Examples:
```pycon
>>> import torch
>>> from multimolecule import ProGen2Config, ProGen2ForCausalLM
>>> config = ProGen2Config()
>>> model = ProGen2ForCausalLM(config)
```
Source code in multimolecule/models/progen2/modeling_progen2.py
```python
class ProGen2ForCausalLM(ProGen2PreTrainedModel, GenerationMixin):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProGen2Config, ProGen2ForCausalLM
        >>> config = ProGen2Config()
        >>> model = ProGen2ForCausalLM(config)
    """

    def __init__(self, config: ProGen2Config):
        super().__init__(config)
        self.model = ProGen2Model(config, add_pooling_layer=False)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits; the loss path casts to float32 for cross-entropy stability
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        lm_logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=lm_logits.to(torch.float32),
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
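`forward` turns an integer `logits_to_keep` into `slice(-logits_to_keep, None)`. Because `-0 == 0`, the default `0` yields `slice(0, None)`, so logits are computed for every position. A value of `1` keeps only the final position, which is all that is needed to choose the next token during generation. A tensor argument is used directly as an index instead.

The plain-Python sketch below mirrors this on position indices. `kept_positions` is a hypothetical helper.

```python
from typing import List, Sequence, Union


def kept_positions(seq_len: int, logits_to_keep: Union[int, Sequence[int]]) -> List[int]:
    """Which sequence positions get logits, mirroring the slicing in ProGen2ForCausalLM.forward."""
    positions = list(range(seq_len))
    if isinstance(logits_to_keep, int):
        return positions[slice(-logits_to_keep, None)]
    return [positions[i] for i in logits_to_keep]  # explicit indices, like indexing with a tensor


for keep in (0, 1, 3, [0, 2]):
    print(keep, "->", kept_positions(5, keep))
```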
ProGen2ForSequencePrediction
Bases: ProGen2PreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import ProGen2Config, ProGen2ForSequencePrediction
>>> config = ProGen2Config()
>>> model = ProGen2ForSequencePrediction(config)
```
Source code in multimolecule/models/progen2/modeling_progen2.py
```python
class ProGen2ForSequencePrediction(ProGen2PreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProGen2Config, ProGen2ForSequencePrediction
        >>> config = ProGen2Config()
        >>> model = ProGen2ForSequencePrediction(config)
    """

    def __init__(self, config: ProGen2Config):
        super().__init__(config)
        self.model = ProGen2Model(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
ProGen2ForTokenPrediction
Bases: ProGen2PreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import ProGen2Config, ProGen2ForTokenPrediction
>>> config = ProGen2Config()
>>> model = ProGen2ForTokenPrediction(config)
```
Source code in multimolecule/models/progen2/modeling_progen2.py
```python
class ProGen2ForTokenPrediction(ProGen2PreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProGen2Config, ProGen2ForTokenPrediction
        >>> config = ProGen2Config()
        >>> model = ProGen2ForTokenPrediction(config)
    """

    def __init__(self, config: ProGen2Config):
        super().__init__(config)
        self.model = ProGen2Model(config, add_pooling_layer=False)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        output = self.token_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
ProGen2Model
Bases: ProGen2PreTrainedModel
Examples:
```pycon
>>> import torch
>>> from multimolecule import ProGen2Config, ProGen2Model
>>> config = ProGen2Config()
>>> model = ProGen2Model(config)
```
Source code in multimolecule/models/progen2/modeling_progen2.py
```python
class ProGen2Model(ProGen2PreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import ProGen2Config, ProGen2Model
        >>> config = ProGen2Config()
        >>> model = ProGen2Model(config)
    """

    def __init__(self, config: ProGen2Config, add_pooling_layer: bool = True):
        super().__init__(config)
        self.pad_token_id = config.pad_token_id
        self.gradient_checkpointing = False
        self.embeddings = ProGen2Embeddings(config)
        self.decoder = ProGen2Decoder(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.rotary_emb = ProGen2RotaryEmbedding(config=config)
        self.pooler = ProGen2Pooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | BaseModelOutputWithPoolingAndCrossAttentions:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        else:
            inputs_embeds = self.embeddings(inputs_embeds=inputs_embeds)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids)

        decoder_outputs = self.decoder(
            inputs_embeds,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        hidden_states = self.layer_norm(decoder_outputs.last_hidden_state)
        pooled_output = self.pooler(hidden_states) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled_output,
            past_key_values=decoder_outputs.past_key_values,
        )
```
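When `cache_position` is not given, `ProGen2Model.forward` numbers the new tokens starting right after the tokens already stored in `past_key_values`. Without explicit `position_ids`, it reuses those numbers as rotary positions.

The sketch below reproduces that numbering, plus a boolean causal mask combined with a padding mask, as a conceptual model of what `create_causal_mask` represents. The real function returns framework-specific tensors (for example additive float masks, or booleans for SDPA) and handles more cases. `cache_positions` and `causal_allowed` are hypothetical helpers.

```python
from typing import List, Optional


def cache_positions(past_seen_tokens: int, num_new_tokens: int) -> List[int]:
    """Like torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)."""
    return list(range(past_seen_tokens, past_seen_tokens + num_new_tokens))


def causal_allowed(
    past_seen_tokens: int, num_new_tokens: int, attention_mask: Optional[List[int]] = None
) -> List[List[bool]]:
    """allowed[i][j] is True if new token i may attend to key position j (cached + new keys)."""
    total = past_seen_tokens + num_new_tokens
    if attention_mask is None:
        attention_mask = [1] * total  # no padding
    return [
        [key <= query and bool(attention_mask[key]) for key in range(total)]
        for query in cache_positions(past_seen_tokens, num_new_tokens)
    ]


# Prefill: three new tokens, nothing cached yet.
for row in causal_allowed(0, 3):
    print(row)
# Decoding step: one new token at position 3 sees the three cached keys plus itself.
print(cache_positions(3, 1), causal_allowed(3, 1))
```

With left padding, for example `attention_mask=[0, 1, 1]`, the row for the padded query is entirely `False`. A softmax over such a row is undefined, so real implementations must handle fully masked rows explicitly.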