Xpresso
Deep convolutional neural network for predicting mRNA abundance directly from genomic promoter sequence.
Disclaimer
This is an UNOFFICIAL implementation of Predicting mRNA Abundance Directly from Genomic Sequence Using Deep Convolutional Neural Networks by Vikram Agarwal, et al.
The OFFICIAL repository of Xpresso is at vagarwal87/Xpresso.
Tip
The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.
The team releasing Xpresso did not write a model card for this model, so this model card has been written by the MultiMolecule team.
Model Details
Xpresso is a deep convolutional neural network (CNN) that predicts steady-state mRNA expression level directly from genomic sequence. It consumes a promoter window of roughly 10.5 kb centered on the transcription start site (TSS), processes it through a stack of 1D convolution + max-pooling blocks, flattens the result, concatenates a small set of auxiliary numeric mRNA half-life features, and passes the combined representation through fully-connected layers to predict a single scalar expression value. Please refer to the Training Details section for more information on the training process.
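To make the data flow concrete, the sketch below traces shapes through the default configuration (conv_channels=[128, 32], pool_sizes=[30, 10], num_features=6, fc_dims=[64, 2]). It assumes that each convolution preserves sequence length ('same' padding) and that max-pooling uses non-overlapping windows that floor the length. The exact padding used by the MultiMolecule implementation is not stated in this card, and trace_shapes is a hypothetical helper, not a library function.

```python
def trace_shapes(
    input_length=10_500,
    conv_channels=(128, 32),
    pool_sizes=(30, 10),
    num_features=6,
    fc_dims=(64, 2),
):
    """Hypothetical helper: trace (length, channels) through an Xpresso-style stack.

    Assumes 'same'-padded convolutions (length unchanged) followed by
    non-overlapping max-pooling that floors the length.
    """
    shapes = []
    length = input_length
    for channels, pool in zip(conv_channels, pool_sizes):
        # The convolution keeps the length (assumed 'same' padding); pooling divides it.
        length //= pool
        shapes.append((length, channels))
    flattened = length * conv_channels[-1]
    # Auxiliary half-life features are concatenated before the fully-connected head.
    fc_input = flattened + num_features
    return shapes, flattened, fc_input, fc_dims[-1]


shapes, flattened, fc_input, hidden_size = trace_shapes()
print(shapes)       # [(350, 128), (35, 32)]
print(flattened)    # 1120
print(fc_input)     # 1126
print(hidden_size)  # 2
```

Under these assumptions the fully-connected head receives 1,126 inputs and emits a 2-dimensional pooled representation, which agrees with the Hidden Size of 2 in the specification table and the pooler_output shape in the XpressoModel example.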
Model Specification
| Input Length | Conv Blocks | Hidden Size | Auxiliary Features | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
|---|---|---|---|---|---|---|---|
| 10,500 | 2 | 2 | 6 | 0.11 | 0.11 | 0.05 | 10,500 |
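The parameter and MAC figures above can be sanity-checked by hand. The sketch below assumes the default configuration, treats the vocab_size=5 token channels as the input channels of the first convolution, and assumes 'same'-padded convolutions with bias, biased linear layers, and a single linear prediction head from fc_dims[-1] to num_labels. estimate_cost is a hypothetical helper, not part of multimolecule.

```python
def estimate_cost(
    in_channels=5,
    input_length=10_500,
    conv_channels=(128, 32),
    conv_kernel_sizes=(6, 9),
    pool_sizes=(30, 10),
    num_features=6,
    fc_dims=(64, 2),
    num_labels=1,
):
    """Rough parameter and multiply-accumulate (MAC) count for an Xpresso-style network.

    Assumptions: first-layer input channels equal vocab_size, 'same'-padded
    convolutions with bias, floor max-pooling, biased linear layers, and a
    single linear prediction head from fc_dims[-1] to num_labels.
    """
    params = macs = 0
    length, prev = input_length, in_channels
    for out, kernel, pool in zip(conv_channels, conv_kernel_sizes, pool_sizes):
        params += prev * out * kernel + out
        macs += length * prev * out * kernel  # one MAC per weight per output position
        length //= pool
        prev = out
    prev = length * prev + num_features  # flattened conv output + auxiliary features
    for width in (*fc_dims, num_labels):
        params += prev * width + width
        macs += prev * width
        prev = width
    return params, macs


params, macs = estimate_cost()
print(params, round(params / 1e6, 2))  # 113125 0.11
print(macs, round(macs / 1e9, 2))      # 53294594 0.05
```

Under these assumptions the estimate matches the 0.11 M parameters and 0.05 G MACs in the table. FLOPs are commonly reported as roughly twice the MAC count, which is consistent with the 0.11 G figure, although counting conventions differ between profiling tools.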
Usage
The model file depends on the multimolecule library. You can install it using pip:
| Bash |
|---|
| pip install multimolecule
|
Direct Use
mRNA Expression Prediction
You can use this model directly to predict the mRNA expression of a promoter sequence together with its auxiliary mRNA half-life features:
| Python |
|---|
| >>> import torch
>>> from multimolecule import DnaTokenizer, XpressoForSequencePrediction
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/xpresso")
>>> model = XpressoForSequencePrediction.from_pretrained("multimolecule/xpresso")
>>> input = tokenizer("ACGTACGTACGTACGT", return_tensors="pt")
>>> features = torch.randn(1, model.config.num_features)
>>> output = model(**input, features=features)
>>> output.logits.shape
torch.Size([1, 1])
|
The auxiliary half-life features are passed through the features argument as a float tensor of shape (batch_size, num_features). Models configured with a non-zero num_features require this tensor; models configured with num_features=0 do not accept it.
Interface
- Input length: fixed 10,500 bp promoter window centered on the TSS
- Padding: shorter inputs are right-padded; longer inputs are center-cropped to input_length
- Auxiliary inputs: a features tensor of shape (batch_size, num_features), required when num_features > 0 and not accepted when num_features = 0
- Output: a scalar mRNA expression value
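The padding rule in the list above can be sketched on plain token-id lists. This assumes the pad id is the configuration default pad_token_id = vocab_size - 1 = 4, and that center-cropping drops any odd leftover token from the end. fit_to_window is a hypothetical helper, and the library may place the crop boundary differently.

```python
def fit_to_window(token_ids, input_length=10_500, pad_id=4):
    """Hypothetical helper: right-pad short inputs, center-crop long ones."""
    if len(token_ids) <= input_length:
        # Shorter inputs: append pad tokens on the right.
        return token_ids + [pad_id] * (input_length - len(token_ids))
    # Longer inputs: keep the central input_length tokens.
    start = (len(token_ids) - input_length) // 2
    return token_ids[start : start + input_length]


print(fit_to_window([0, 1, 2], input_length=5))              # [0, 1, 2, 4, 4]
print(fit_to_window([0, 1, 2, 3, 0, 1, 2], input_length=4))  # [1, 2, 3, 0]
```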
Training Details
Xpresso was trained to predict steady-state mRNA expression levels (median across tissues/cell lines) from genomic promoter sequence.
Training Data
Xpresso was trained on human and mouse genes, using promoter sequences (~10.5 kb windows centered on the TSS) together with mRNA half-life features derived from gene-body and UTR properties. Expression targets are log-transformed median mRNA levels across tissues.
The Xpresso model follows the published humanMedian configuration.
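As an illustration of a log-transformed median target, the sketch below takes the median of per-tissue expression values and applies a base-10 logarithm with a small pseudocount. The log base and the pseudocount of 0.1 are assumptions for this example. The card does not state the exact transform used for the released checkpoint, and log_median_target is a hypothetical helper.

```python
import math
import statistics


def log_median_target(levels, pseudocount=0.1):
    """Hypothetical target: log10(median expression across tissues + pseudocount)."""
    # The pseudocount keeps genes with zero median expression finite.
    return math.log10(statistics.median(levels) + pseudocount)


print(log_median_target([0.9, 9.9, 99.9]))  # 1.0
```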
Training Procedure
Pre-training
The model was trained to minimize a mean-squared-error loss between predicted and observed log mRNA expression values.
- Optimizer: Adam
- Loss: Mean squared error
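The objective and optimizer listed above can be illustrated without a deep-learning framework. The sketch below implements the mean squared error and one standard Adam update (default betas 0.9 and 0.999, eps 1e-8) on a flat list of parameters, then fits a single toy parameter. The learning rate and step count are placeholders, not the hyperparameters used to train Xpresso, and mse and adam_step are hypothetical helpers.

```python
import math


def mse(predictions, targets):
    """Mean squared error between predicted and observed log expression values."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def adam_step(params, grads, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One standard Adam update; params, m and v are updated in place."""
    for i, g in enumerate(grads):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        m_hat = m[i] / (1 - beta1**step)  # bias-corrected first moment
        v_hat = v[i] / (1 - beta2**step)  # bias-corrected second moment
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return params


# Toy regression: a single scalar "prediction" w fitted to a target of 2.0.
w, m, v = [0.0], [0.0], [0.0]
for step in range(1, 2001):
    grad = [2 * (w[0] - 2.0)]  # derivative of (w - 2)^2 with respect to w
    adam_step(w, grad, m, v, step, lr=0.05)
print(round(w[0], 3), mse(w, [2.0]))
```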
Citation
| BibTeX |
|---|
| @article{agarwal2020predicting,
author = {Agarwal, Vikram and Shendure, Jay},
journal = {Cell Reports},
number = 7,
pages = {107663},
publisher = {Elsevier BV},
title = {Predicting mRNA Abundance Directly from Genomic Sequence Using Deep Convolutional Neural Networks},
volume = 31,
year = 2020,
doi = {10.1016/j.celrep.2020.107663}
}
|
Note
The artifacts distributed in this repository are part of the MultiMolecule project.
If MultiMolecule supports your research, please cite the MultiMolecule project as follows:
| BibTeX |
|---|
| @software{chen_2024_12638419,
author = {Chen, Zhiyuan and Zhu, Sophia Y.},
title = {MultiMolecule},
doi = {10.5281/zenodo.12638419},
publisher = {Zenodo},
url = {https://doi.org/10.5281/zenodo.12638419},
year = 2024,
month = may,
day = 4
}
|
Please use GitHub issues of MultiMolecule for any questions or comments on the model card.
Please contact the authors of the Xpresso paper for questions or comments on the paper/model.
License
This model implementation is licensed under the GNU Affero General Public License.
For additional terms and clarifications, please refer to our License FAQ.
| Text Only |
|---|
| SPDX-License-Identifier: AGPL-3.0-or-later
|
API Reference
XpressoConfig
Bases: PreTrainedConfig
This is the configuration class to store the configuration of an XpressoModel. It is used to instantiate an Xpresso model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the Xpresso vagarwal87/Xpresso architecture.
Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Vocabulary size of the Xpresso model. Defines the number of feature channels derived from input_ids for the first convolution. | 5 |
| input_length | int | The length of the promoter sequence window (centered on the TSS) consumed by the convolutional stack. | 10500 |
| num_conv_layers | int | Number of convolutional blocks in the encoder. | 2 |
| conv_channels | list[int] \| None | Number of output channels for each convolutional block. Length must equal num_conv_layers. | None ([128, 32]) |
| conv_kernel_sizes | list[int] \| None | Convolution kernel size for each convolutional block. Length must equal num_conv_layers. | None ([6, 9]) |
| conv_dilations | list[int] \| None | Dilation factor for each convolutional block. Length must equal num_conv_layers. | None ([1, 1]) |
| pool_sizes | list[int] \| None | Max-pooling window for each convolutional block. Length must equal num_conv_layers. | None ([30, 10]) |
| num_features | int | Number of auxiliary numeric mRNA half-life features concatenated with the convolutional representation before the fully-connected head. | 6 |
| fc_dims | list[int] \| None | Dimensionality of each fully-connected layer in the head. The last entry also sets hidden_size. | None ([64, 2]) |
| hidden_act | str | The non-linear activation function (function or string) in the encoder and the head. If string, "gelu", "relu", "silu" and "gelu_new" are supported. | 'relu' |
| hidden_dropout | float | The dropout probability applied after each fully-connected layer. | 0.00099 |
| num_labels | int | Number of output labels. Xpresso predicts a single scalar mRNA expression value. | 1 |
| head | HeadConfig \| None | The configuration of the prediction head. Defaults to a regression head (problem_type="regression"), matching Xpresso's mRNA abundance prediction task. | None |
Examples:
| Python Console Session |
|---|
| >>> from multimolecule import XpressoConfig, XpressoModel
>>> # Initializing a Xpresso multimolecule/xpresso style configuration
>>> configuration = XpressoConfig()
>>> # Initializing a model (with random weights) from the multimolecule/xpresso style configuration
>>> model = XpressoModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
|
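The default filling and length checks in XpressoConfig have one practical consequence: changing num_conv_layers alone is rejected, because the default per-layer lists have two entries. The sketch below is a hypothetical, dependency-free mirror of part of that logic (resolve_xpresso_layout is not a multimolecule function, and its error messages only loosely follow the real ones). Use the real XpressoConfig for actual models.

```python
def resolve_xpresso_layout(
    num_conv_layers=2,
    conv_channels=None,
    conv_kernel_sizes=None,
    conv_dilations=None,
    pool_sizes=None,
    fc_dims=None,
):
    """Hypothetical, dependency-free mirror of part of XpressoConfig's defaults and checks."""
    per_layer = {
        # `None` falls back to the two-block defaults used by XpressoConfig.
        "conv_channels": [128, 32] if conv_channels is None else conv_channels,
        "conv_kernel_sizes": [6, 9] if conv_kernel_sizes is None else conv_kernel_sizes,
        "conv_dilations": [1, 1] if conv_dilations is None else conv_dilations,
        "pool_sizes": [30, 10] if pool_sizes is None else pool_sizes,
    }
    for name, value in per_layer.items():
        if len(value) != num_conv_layers:
            raise ValueError(f"`{name}` must have length {num_conv_layers}, got {len(value)}.")
    fc_dims = [64, 2] if fc_dims is None else fc_dims
    if not fc_dims:
        raise ValueError("`fc_dims` must contain at least one fully-connected dimension.")
    # hidden_size (the width consumed by the prediction head) is the last fully-connected width.
    return {**per_layer, "fc_dims": fc_dims, "hidden_size": fc_dims[-1]}


# Changing num_conv_layers alone fails, because the default per-layer lists have two entries.
try:
    resolve_xpresso_layout(num_conv_layers=3)
except ValueError as error:
    message = str(error)
print(message)  # `conv_channels` must have length 3, got 2.

# A three-block layout has to override every per-layer list.
three = resolve_xpresso_layout(3, [128, 64, 32], [6, 9, 9], [1, 1, 1], [10, 10, 5], [64, 8])
print(three["hidden_size"])  # 8
```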
Source code in multimolecule/models/xpresso/configuration_xpresso.py
| Python |
|---|
| class XpressoConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`XpressoModel`][multimolecule.models.XpressoModel]. It is used to instantiate a Xpresso model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the Xpresso
    [vagarwal87/Xpresso](https://github.com/vagarwal87/Xpresso) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the Xpresso model. Defines the number of feature channels derived from `input_ids` for
            the first convolution. Defaults to 5.
        input_length:
            The length of the promoter sequence window (centered on the TSS) consumed by the convolutional stack.
        num_conv_layers:
            Number of convolutional blocks in the encoder.
        conv_channels:
            Number of output channels for each convolutional block. Length must equal `num_conv_layers`.
        conv_kernel_sizes:
            Convolution kernel size for each convolutional block. Length must equal `num_conv_layers`.
        conv_dilations:
            Dilation factor for each convolutional block. Length must equal `num_conv_layers`.
        pool_sizes:
            Max-pooling window for each convolutional block. Length must equal `num_conv_layers`.
        num_features:
            Number of auxiliary numeric mRNA half-life features concatenated with the convolutional representation
            before the fully-connected head.
        fc_dims:
            Dimensionality of each fully-connected layer in the head.
        hidden_act:
            The non-linear activation function (function or string) in the encoder and the head. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout:
            The dropout probability applied after each fully-connected layer.
        num_labels:
            Number of output labels. Xpresso predicts a single scalar mRNA expression value.
        head:
            The configuration of the prediction head. Defaults to a regression head
            (`problem_type="regression"`), matching Xpresso's mRNA abundance prediction task.

    Examples:
        >>> from multimolecule import XpressoConfig, XpressoModel
        >>> # Initializing a Xpresso multimolecule/xpresso style configuration
        >>> configuration = XpressoConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/xpresso style configuration
        >>> model = XpressoModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "xpresso"

    def __init__(
        self,
        vocab_size: int = 5,
        input_length: int = 10500,
        num_conv_layers: int = 2,
        conv_channels: list[int] | None = None,
        conv_kernel_sizes: list[int] | None = None,
        conv_dilations: list[int] | None = None,
        pool_sizes: list[int] | None = None,
        num_features: int = 6,
        fc_dims: list[int] | None = None,
        hidden_act: str = "relu",
        hidden_dropout: float = 0.00099,
        num_labels: int = 1,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        kwargs.setdefault("pad_token_id", vocab_size - 1)
        kwargs.setdefault("unk_token_id", vocab_size - 1)
        kwargs.setdefault("bos_token_id", None)
        kwargs.setdefault("eos_token_id", None)
        kwargs.setdefault("mask_token_id", None)
        kwargs.setdefault("null_token_id", None)
        super().__init__(num_labels=num_labels, **kwargs)
        self.vocab_size = vocab_size
        self.input_length = input_length
        self.num_conv_layers = num_conv_layers
        if conv_channels is None:
            conv_channels = [128, 32]
        if conv_kernel_sizes is None:
            conv_kernel_sizes = [6, 9]
        if conv_dilations is None:
            conv_dilations = [1, 1]
        if pool_sizes is None:
            pool_sizes = [30, 10]
        if fc_dims is None:
            fc_dims = [64, 2]
        self.conv_channels = conv_channels
        self.conv_kernel_sizes = conv_kernel_sizes
        self.conv_dilations = conv_dilations
        self.pool_sizes = pool_sizes
        self.num_features = num_features
        self.fc_dims = fc_dims
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.num_labels = num_labels
        # Validate before reading `fc_dims[-1]`, so an empty `fc_dims` raises ValueError rather than IndexError.
        self._validate()
        # `hidden_size` is the dimensionality of the pooled representation consumed by
        # `SequencePredictionHead`; it equals the width of the last fully-connected layer.
        self.hidden_size = self.fc_dims[-1]
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head

    def _validate(self) -> None:
        per_layer = {
            "conv_channels": self.conv_channels,
            "conv_kernel_sizes": self.conv_kernel_sizes,
            "conv_dilations": self.conv_dilations,
            "pool_sizes": self.pool_sizes,
        }
        for name, value in per_layer.items():
            if len(value) != self.num_conv_layers:
                raise ValueError(
                    f"`{name}` must have length `num_conv_layers` ({self.num_conv_layers}), got {len(value)}."
                )
        if self.input_length <= 0:
            raise ValueError(f"`input_length` must be positive, got {self.input_length}.")
        if self.num_features < 0:
            raise ValueError(f"`num_features` must be non-negative, got {self.num_features}.")
        if not self.fc_dims:
            raise ValueError("`fc_dims` must contain at least one fully-connected dimension.")
|
XpressoForSequencePrediction
Bases: XpressoPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import XpressoConfig, XpressoForSequencePrediction, DnaTokenizer
>>> config = XpressoConfig()
>>> model = XpressoForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/xpresso")
>>> input = tokenizer(["ACGTACGTACGT", "TGCATGCATGCA"], return_tensors="pt")
>>> features = torch.randn(2, config.num_features)
>>> output = model(**input, features=features, labels=torch.randn(2, 1))
>>> output["logits"].shape
torch.Size([2, 1])
|
Source code in multimolecule/models/xpresso/modeling_xpresso.py
| Python |
|---|
| class XpressoForSequencePrediction(XpressoPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import XpressoConfig, XpressoForSequencePrediction, DnaTokenizer
        >>> config = XpressoConfig()
        >>> model = XpressoForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/xpresso")
        >>> input = tokenizer(["ACGTACGTACGT", "TGCATGCATGCA"], return_tensors="pt")
        >>> features = torch.randn(2, config.num_features)
        >>> output = model(**input, features=features, labels=torch.randn(2, 1))
        >>> output["logits"].shape
        torch.Size([2, 1])
    """

    def __init__(self, config: XpressoConfig):
        super().__init__(config)
        self.model = XpressoModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_labels == 1:
            return ["expression"]
        return [f"expression_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        features: Tensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            features=features,
            return_dict=True,
            **kwargs,
        )
        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
XpressoModel
Bases: XpressoPreTrainedModel
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import XpressoConfig, XpressoModel, DnaTokenizer
>>> config = XpressoConfig()
>>> model = XpressoModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/xpresso")
>>> input = tokenizer(["ACGTACGTACGT", "TGCATGCATGCA"], return_tensors="pt")
>>> features = torch.randn(2, config.num_features)
>>> output = model(**input, features=features)
>>> output["pooler_output"].shape
torch.Size([2, 2])
|
Source code in multimolecule/models/xpresso/modeling_xpresso.py
| Python |
|---|
| class XpressoModel(XpressoPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import XpressoConfig, XpressoModel, DnaTokenizer
        >>> config = XpressoConfig()
        >>> model = XpressoModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/xpresso")
        >>> input = tokenizer(["ACGTACGTACGT", "TGCATGCATGCA"], return_tensors="pt")
        >>> features = torch.randn(2, config.num_features)
        >>> output = model(**input, features=features)
        >>> output["pooler_output"].shape
        torch.Size([2, 2])
    """

    def __init__(self, config: XpressoConfig):
        super().__init__(config)
        self.embeddings = XpressoEmbedding(config)
        self.encoder = XpressoEncoder(config)
        self.head = XpressoHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Xpresso's `last_hidden_state` is the *flattened* convolutional representation, not a
    # per-position layer output, so it must not be tied into the recorded `hidden_states` tuple.
    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        features: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> XpressoModelOutput:
        """
        Args:
            input_ids: Token ids of the promoter sequence.
            attention_mask: Binary mask; 1 for real tokens, 0 for padding.
            inputs_embeds: Pre-computed one-hot (or soft) embeddings. Mutually exclusive with
                `input_ids`.
            features: Optional auxiliary tensor of shape `(batch_size, config.num_features)`
                containing numeric mRNA half-life features (e.g. 3′-UTR length, Kozak score).
                Required when `config.num_features > 0`; must be `None` when
                `config.num_features == 0`. The tensor is concatenated with the flattened
                convolutional representation before the fully-connected head.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        if input_ids is not None:
            batch_size = input_ids.size(0)
        else:
            if inputs_embeds is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            batch_size = inputs_embeds.size(0)
        self._validate_features(features, batch_size)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(embedding_output, **kwargs)
        conv_output = encoder_outputs.last_hidden_state
        pooler_output = self.head(conv_output, features=features)

        return XpressoModelOutput(
            last_hidden_state=conv_output,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=None,
        )

    def _validate_features(self, features: Tensor | None, batch_size: int) -> None:
        if self.config.num_features == 0:
            if features is not None:
                raise ValueError(
                    "This Xpresso model is configured with num_features=0 and does not accept a `features` tensor."
                )
            return
        if features is None:
            raise ValueError(
                f"This Xpresso model is configured with num_features={self.config.num_features}; "
                "you must pass the auxiliary `features` tensor."
            )
        if features.ndim != 2:
            raise ValueError(
                "`features` must be a 2D tensor of shape "
                f"(batch_size, {self.config.num_features}), got shape {tuple(features.shape)}."
            )
        if features.size(0) != batch_size:
            raise ValueError(f"`features` batch size ({features.size(0)}) must match input batch size ({batch_size}).")
        if features.size(1) != self.config.num_features:
            raise ValueError(
                f"`features` last dimension ({features.size(1)}) must equal "
                f"`config.num_features` ({self.config.num_features})."
            )
|
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| NestedTensor \| None | Token ids of the promoter sequence. | None |
| attention_mask | Tensor \| None | Binary mask; 1 for real tokens, 0 for padding. | None |
| inputs_embeds | Tensor \| NestedTensor \| None | Pre-computed one-hot (or soft) embeddings. Mutually exclusive with input_ids. | None |
| features | Tensor \| None | Auxiliary tensor of shape (batch_size, config.num_features) containing numeric mRNA half-life features (e.g. 3′-UTR length, Kozak score). Required when config.num_features > 0; must be None when config.num_features == 0. The tensor is concatenated with the flattened convolutional representation before the fully-connected head. | None |
XpressoModelOutput
dataclass
Bases: ModelOutput
Base class for outputs of the Xpresso backbone.
Parameters:
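| Name | Type | Description | Default |
|---|---|---|---|
| last_hidden_state | `torch.FloatTensor` of shape `(batch_size, flattened_conv_size)` | Flattened convolutional representation of the promoter sequence. | None |
| pooler_output | `torch.FloatTensor` of shape `(batch_size, hidden_size)` | Final fully-connected representation, with the auxiliary mRNA half-life features fused in. This is the tensor consumed by SequencePredictionHead. | None |
| hidden_states | `tuple(torch.FloatTensor)`, *optional* | Returned when output_hidden_states=True is passed or when config.output_hidden_states=True. Tuple of torch.FloatTensor (one for the embedding output plus one after each convolutional block) of shape (batch_size, length, channels): the convolutional feature maps recorded along the encoder stack. | None |
| attentions | always `None` | Xpresso is a purely convolutional architecture and has no attention; this field is always None and is present only for compatibility with the Transformers output convention. | None |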
Source code in multimolecule/models/xpresso/modeling_xpresso.py
| Python |
|---|
| @dataclass
class XpressoModelOutput(ModelOutput):
    """
    Base class for outputs of the Xpresso backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, flattened_conv_size)`):
            Flattened convolutional representation of the promoter sequence.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Final fully-connected representation, with the auxiliary mRNA half-life features fused in. This is the
            tensor consumed by [`SequencePredictionHead`][multimolecule.modules.SequencePredictionHead].
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the embedding output plus one after each convolutional block) of
            shape `(batch_size, length, channels)`. Convolutional feature maps recorded along the encoder stack.
        attentions (always `None`):
            Xpresso is a purely convolutional architecture and has no attention; this field is always `None` and is
            present only for compatibility with the Transformers output convention.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
|
XpressoPreTrainedModel
Bases: PreTrainedModel
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Source code in multimolecule/models/xpresso/modeling_xpresso.py
| Python |
|---|
| class XpressoPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = XpressoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["XpressoBlock"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
|
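The initialization rules in _init_weights translate into concrete scales. The sketch below computes them with the standard library, assuming the default configuration and 'same'-padded convolutions, so that the first fully-connected layer has fan_in = 1,120 + 6 = 1,126. linear_init_bounds and conv_init_std are hypothetical helpers that restate the formulas behind PyTorch's kaiming_uniform_ and kaiming_normal_. With a = sqrt(5), the weight bound reduces to 1/sqrt(fan_in), the same bound the code applies to the bias.

```python
import math


def linear_init_bounds(fan_in, a=math.sqrt(5)):
    """Uniform bounds for an nn.Linear weight (kaiming_uniform_, a=sqrt(5)) and its bias."""
    # leaky_relu gain sqrt(2 / (1 + a^2)); with a = sqrt(5) the gain is sqrt(1/3).
    gain = math.sqrt(2.0 / (1 + a**2))
    weight_bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
    bias_bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    return weight_bound, bias_bound


def conv_init_std(out_channels, kernel_size):
    """Std of kaiming_normal_(mode="fan_out", nonlinearity="relu") for an nn.Conv1d weight."""
    fan_out = out_channels * kernel_size
    return math.sqrt(2.0 / fan_out)


# First fully-connected layer (assumed fan_in 1,126) and first convolution (128 channels, kernel 6).
w_bound, b_bound = linear_init_bounds(1126)
print(round(w_bound, 4), round(b_bound, 4))  # 0.0298 0.0298
print(round(conv_init_std(128, 6), 4))       # 0.051
```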