
Framepool

Frame-aware pooling convolutional network for predicting mean ribosome load from variable-length 5’UTR sequences.

Disclaimer

This is an UNOFFICIAL implementation of Predicting mean ribosome load for 5’UTR of any length using deep learning by Alexander Karollus, et al.

The OFFICIAL repository of Framepool is at Karollus/5UTR and the published Kipoi wrapper is at kipoi/models.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing Framepool did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

Framepool is a small 1D convolutional network that predicts the mean ribosome load (MRL) of a human 5’ untranslated region from sequence alone. It extends the fixed-length network of Sample et al., 2019 with a frame-aware pooling layer that reverses the sequence to anchor reading frames at the start codon, slices the convolutional feature map into the three reading frames, and applies global max and masked global average pooling per frame. The pooled representation is length-independent and is consumed by a small dense head followed by a per-sub-library scaling regression that recalibrates the prediction across the two training libraries (egfp_unmod_1 and random). Please refer to the Training Details section for more information on the training process.
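The pooling step can be illustrated with a minimal pure-Python sketch. It assumes a (length, channels) feature map stored as nested lists in 5' to 3' order, a 0/1 mask that marks real nucleotides, padding on the 5' side (so the 3' end stays anchored at the start codon after reversal), and a max-then-average concatenation order. The function name frame_pool is hypothetical, and for simplicity both pools skip masked positions; the released implementation may order or mask the pooled vectors differently.

```python
from typing import List, Sequence


def frame_pool(
    features: Sequence[Sequence[float]],
    mask: Sequence[int],
    only_max_pool: bool = False,
) -> List[float]:
    """Frame-aware pooling over a (length, channels) feature map ordered 5' -> 3'.

    mask[i] is 1 for real nucleotides and 0 for padding / N. Padding is assumed
    to sit on the 5' side so the 3' end stays anchored at the start codon.
    """
    # Reverse so index 0 is the nucleotide immediately upstream of the start codon.
    rev_features = list(reversed(features))
    rev_mask = list(reversed(mask))
    num_channels = len(features[0])
    max_pooled: List[float] = []
    avg_pooled: List[float] = []
    for frame in range(3):
        # Frame f collects reversed positions f, f + 3, f + 6, ... that are not masked.
        rows = [rev_features[i] for i in range(frame, len(rev_features), 3) if rev_mask[i]]
        for c in range(num_channels):
            column = [row[c] for row in rows]
            max_pooled.append(max(column) if column else 0.0)
            avg_pooled.append(sum(column) / len(column) if column else 0.0)
    # 3 max-pooled vectors, optionally followed by 3 masked-average vectors.
    return max_pooled if only_max_pool else max_pooled + avg_pooled
```

With 128 channels the default (max plus average) output has 128 × 6 = 768 entries, which is the Hidden Size listed in the Model Specification, regardless of the input length.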

The released combined_residual model is recommended by the upstream authors for variant effect scoring.

Model Specification

| Num Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- |
| 4 | 768 | 0.28 | 0.05 | 0.02 | unlimited |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Mean Ribosome Load Prediction

You can use this model directly to predict the mean ribosome load of a 5’UTR sequence:

Python
>>> from multimolecule import RnaTokenizer, FramepoolForSequencePrediction

>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/framepool")
>>> model = FramepoolForSequencePrediction.from_pretrained("multimolecule/framepool")
>>> output = model(**tokenizer("ACGUACGUACGUACGUACGUACGUACGUACGUACGUACGUACGUACGUAC", return_tensors="pt"))

>>> output.keys()
odict_keys(['logits'])

Interface

  • Input length: variable; the upstream MPRA training data is 25-100 nt 5’UTR but the model accepts any length because of frame-aware pooling
  • Alphabet: RNA (A, C, G, U); RnaTokenizer converts T to U; N and other non-canonical tokens are encoded as all-zero columns and ignored by the masked pooling
  • Padding: zero-padding is supported via attention_mask and is excluded from pooling
  • Output: a single scalar per sequence, the predicted mean ribosome load (logits, shape (batch_size, 1))
  • Auxiliary inputs: optional library_indicator (shape (batch_size, library_size)) selecting one of the two training sub-libraries for the scaling regression. Defaults to the random library, matching the upstream Kipoi variant effect interface
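The alphabet and masking rules above can be sketched in plain Python as follows. The channel order A, C, G, U, N and the helper name encode are assumptions for illustration; the actual tokenizer and FramepoolEmbedding may differ in details such as how padding is represented.

```python
from typing import List, Tuple

# Assumed channel order, matching the A, C, G, U, N alphabet (vocab_size=5).
ALPHABET = "ACGUN"
NULL_CHANNEL_ID = 4  # index of N, zeroed as with null_channel_id=4


def encode(sequence: str) -> Tuple[List[List[int]], List[int]]:
    """One-hot encode a 5'UTR and derive the pad mask from the channel sums."""
    sequence = sequence.upper().replace("T", "U")  # T is read as U
    one_hot: List[List[int]] = []
    for base in sequence:
        column = [0] * len(ALPHABET)
        # Any character outside A/C/G/U is treated as N.
        column[ALPHABET.index(base) if base in "ACGU" else NULL_CHANNEL_ID] = 1
        column[NULL_CHANNEL_ID] = 0  # zero the null channel -> all-zero column
        one_hot.append(column)
    # Like pad_mask = sum(one_hot) over channels: 1 for A/C/G/U, 0 for N or unknown.
    pad_mask = [sum(column) for column in one_hot]
    return one_hot, pad_mask
```

Zeroing the N column is what lets a plain channel sum double as the mask that the masked average pooling uses.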

Variant Effect

Framepool supports paired reference/alternative scoring through the optional alternative_input_ids argument:

  • Single sequence (reference only): logits is the predicted mean ribosome load (one scalar per sequence)
  • Reference + alternative: logits still holds the reference MRL, alternative_logits holds the alternative MRL, and delta holds the log2 mean ribosome load fold change log2(MRL_alt / MRL_ref), matching the Kipoi UTRVariantEffectModel.predict_on_batch mrl_fold_change output
  • Reference and alternative sequences are scored independently; the forward pass applies the same library_indicator to both, so both predictions go through the same library scaling before the fold change is taken
  • For the upstream “shifted-frame” variant effect outputs (shift_1, shift_2), prepend one or two zero columns (or N tokens) to both reference and alternative inputs before scoring, matching the Kipoi loop
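A sketch of the paired scoring described in these bullets, assuming predict_mrl is any callable that maps a sequence to a positive MRL (for example a thin wrapper around FramepoolForSequencePrediction, not shown). The helper variant_effect and its output keys are hypothetical, chosen to echo the Kipoi names mentioned above; this is not a reproduction of the Kipoi loop.

```python
import math
from typing import Callable, Dict


def variant_effect(
    predict_mrl: Callable[[str], float], reference: str, alternative: str
) -> Dict[str, float]:
    """log2(MRL_alt / MRL_ref) for the unshifted pair and 1- and 2-token shifts."""
    effects: Dict[str, float] = {}
    for shift, key in ((0, "mrl_fold_change"), (1, "shift_1"), (2, "shift_2")):
        prefix = "N" * shift  # prepend the same number of N tokens to both sequences
        mrl_ref = predict_mrl(prefix + reference)
        mrl_alt = predict_mrl(prefix + alternative)
        # log2 is only defined for positive predictions; this sketch assumes MRL > 0.
        effects[key] = math.log2(mrl_alt / mrl_ref)
    return effects
```

Because the same prefix is added to both sequences, any change between the three entries reflects how the model responds to the shift, not a mismatch between reference and alternative inputs.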

Training Details

Framepool was trained on polysome-profiling MPRA data measuring the mean ribosome load of randomized 5’UTR sequences and uses frame-aware pooling so that a single network can score sequences of arbitrary length.

Training Data

Framepool was trained on the eGFP polysome-profiling MPRA libraries of Sample et al., 2019 in HEK293T cells: the fixed-length library (egfp_unmod_1, 50 nt) and the variable-length library (random, 25-100 nt). Approximately 260,000 sequences were used for training, with 20,000 held out for testing; additional validation was performed on endogenous data.

Training Procedure

Pre-training

  • Loss: mean squared error between the predicted and measured mean ribosome load
  • Optimizer: Adam with lr = 1e-3, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-8
  • Epochs: 6
  • Mini-batch sampling: the two training libraries are mixed within every batch; a one-hot library indicator is fed to the scaling regression layer so that the network can absorb the library-specific offset
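The mixed-library batching and scaling regression can be sketched as below. The one-hot indicator follows the Interface description (shape (batch_size, library_size), index 1 for random). The per-library affine form of scale (slope times prediction plus intercept) is an assumption made for illustration only, and all function names are hypothetical.

```python
from typing import List, Sequence

LIBRARIES = ("egfp_unmod_1", "random")  # index 1 ("random") is the default library_index


def library_indicator(library_ids: Sequence[int], library_size: int = 2) -> List[List[float]]:
    """One-hot indicator of shape (batch_size, library_size), one row per sequence."""
    return [[1.0 if j == i else 0.0 for j in range(library_size)] for i in library_ids]


def scale(
    unscaled: Sequence[float],
    indicator: Sequence[Sequence[float]],
    slopes: Sequence[float],
    intercepts: Sequence[float],
) -> List[float]:
    """ASSUMED per-library affine scaling; the released layer's exact form may differ."""
    return [
        sum(w * (a * x + b) for w, a, b in zip(row, slopes, intercepts))
        for x, row in zip(unscaled, indicator)
    ]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between predicted and measured MRL."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# A mixed batch: sequence 0 comes from egfp_unmod_1, sequences 1 and 2 from random.
indicator = library_indicator([0, 1, 1])
predicted = scale([1.0, 1.0, 2.0], indicator, slopes=[2.0, 3.0], intercepts=[0.5, -1.0])
```

The shared network produces the same unscaled value for the first two sequences; only the library row selects a different calibration, which is how a library-specific offset can be absorbed without splitting the model.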

Citation

BibTeX
@article{karollus2021predicting,
  author    = {Karollus, Alexander and Avsec, {\v Z}iga and Gagneur, Julien},
  title     = {Predicting mean ribosome load for 5{\textquoteright}UTR of any length using deep learning},
  journal   = {PLOS Computational Biology},
  volume    = {17},
  number    = {5},
  pages     = {e1008982},
  year      = {2021},
  publisher = {Public Library of Science},
  doi       = {10.1371/journal.pcbi.1008982}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If MultiMolecule supports your research, please cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the Framepool paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

FramepoolConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a FramepoolModel. It is used to instantiate a Framepool model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Framepool combined_residual architecture released with the Karollus et al., 2021 paper.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default: 5): Number of one-hot input channels derived from the MultiMolecule RNA tokenizer. Defaults to 5 (A, C, G, U, N), matching the MultiMolecule RNA streamline alphabet. The upstream checkpoint only learns the four canonical channels, with upstream T exposed as U; the embedding zeroes the N channel, matching the upstream compute_pad_mask semantics.
  • null_channel_id (int | None, default: 4): Channel index that represents the upstream "no nucleobase" token (N and padding). The embedding zeroes this column so that the upstream pad_mask = sum(one_hot, axis=2) mask correctly identifies padded positions. Set to None to keep all channels.
  • num_conv_layers (int, default: 3): Number of stacked length-preserving residual convolutions in the encoder.
  • conv_channels (int, default: 128): Number of output channels for every convolution in the encoder.
  • kernel_size (int | list[int], default: 7): Kernel sizes of the encoder convolutions. Either a scalar shared across all layers, or a list with one entry per layer.
  • dilations (int | list[int], default: 1): Dilation rates of the encoder convolutions. Either a scalar shared across all layers, or a list with one entry per layer.
  • hidden_act (str, default: 'relu'): Non-linear activation applied after each encoder convolution.
  • padding (str, default: 'same'): Convolution padding mode. 'same' keeps the sequence length; 'causal' left-pads to retain it.
  • skip_connections (str, default: 'residual'): 'residual' adds the input of every convolution past the first to its output (the configuration used by the released checkpoint); "" disables skip connections.
  • num_dense_layers (int, default: 1): Number of fully-connected layers between the frame-pooled representation and the unscaled MRL output.
  • dense_sizes (list[int] | None, default: None): Hidden sizes of the fully-connected layers. Length must match num_dense_layers; None uses 64 units per layer.
  • dense_dropout (float, default: 0.2): Dropout probability applied after every fully-connected layer.
  • only_max_pool (bool, default: False): If True, the frame pooler concatenates only the per-frame global max pooled features (3 vectors). Otherwise it additionally concatenates the masked global average pooled features (6 vectors), matching the released checkpoint.
  • library_size (int, default: 2): Number of training sub-libraries supported by the scaling regression head. The released checkpoint was trained jointly on the egfp_unmod_1 and random libraries, so library_size = 2.
  • library_index (int, default: 1): Default training sub-library index used to construct the one-hot library indicator at inference. Matches the random library used for variant effect prediction upstream (Kipoi UTRVariantEffectModel).
  • num_labels (int, default: 1): Number of scalar outputs of the model. Framepool predicts a single scalar mean ribosome load value; any other value raises a ValueError.
  • head (HeadConfig | None, default: None): Configuration of the FramepoolForSequencePrediction regression head. None builds a regression HeadConfig.
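For the padding parameter, the arithmetic that keeps a stride-1 convolution length-preserving can be sketched as follows. This is the standard formula for dilated 1D convolutions, not code taken from FramepoolEncoder; how the encoder splits an odd amount of 'same' padding is an assumption here (the extra column goes to the right).

```python
from typing import Tuple


def conv_padding(kernel_size: int, dilation: int, mode: str) -> Tuple[int, int]:
    """(left, right) zero padding that keeps a stride-1 Conv1d length-preserving."""
    total = dilation * (kernel_size - 1)  # span of the dilated kernel minus one
    if mode == "same":
        left = total // 2  # any odd extra column goes to the right (assumption)
        return left, total - left
    if mode == "causal":
        return total, 0  # all on the left: no output sees inputs on its 3' side
    raise ValueError(f"padding must be 'same' or 'causal', got {mode!r}.")


def output_length(length: int, kernel_size: int, dilation: int, mode: str) -> int:
    """Stride-1 Conv1d output length: L + left + right - dilation * (kernel_size - 1)."""
    left, right = conv_padding(kernel_size, dilation, mode)
    return length + left + right - dilation * (kernel_size - 1)
```

With the defaults (kernel_size=7, dilations=1), 'same' pads 3 positions on each side and 'causal' pads 6 on the left; both leave the sequence length unchanged, which is what allows the residual additions.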

Examples:

Python Console Session
>>> from multimolecule import FramepoolConfig, FramepoolModel
>>> # Initializing a Framepool combined_residual style configuration
>>> configuration = FramepoolConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = FramepoolModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/framepool/configuration_framepool.py
Python
class FramepoolConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`FramepoolModel`][multimolecule.models.FramepoolModel]. It is used to instantiate a Framepool model according to
    the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the Framepool ``combined_residual`` architecture released with the
    [Karollus et al., 2021](https://doi.org/10.1371/journal.pcbi.1008982) paper.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Number of one-hot input channels derived from the MultiMolecule RNA tokenizer. Defaults to 5
            (``A``, ``C``, ``G``, ``U``, ``N``), matching the MultiMolecule RNA ``streamline`` alphabet. The upstream
            checkpoint only learns the four canonical channels, with upstream ``T`` exposed as ``U``; the embedding
            zeroes the ``N`` channel, matching the upstream ``compute_pad_mask`` semantics.
        null_channel_id:
            Channel index that represents the upstream "no nucleobase" token (``N`` and padding). The embedding
            zeroes this column so that the upstream ``pad_mask = sum(one_hot, axis=2)`` mask correctly identifies
            padded positions. Set to ``None`` to keep all channels.
        num_conv_layers:
            Number of stacked length-preserving residual convolutions in the encoder.
        conv_channels:
            Number of output channels for every convolution in the encoder.
        kernel_size:
            Kernel sizes of the encoder convolutions. Either a scalar shared across all layers, or a list with one
            entry per layer.
        dilations:
            Dilation rates of the encoder convolutions. Either a scalar shared across all layers, or a list with one
            entry per layer.
        hidden_act:
            Non-linear activation applied after each encoder convolution.
        padding:
            Convolution padding mode. ``same`` keeps the sequence length; ``causal`` left-pads to retain it.
        skip_connections:
            ``residual`` adds the input of every conv past the first to its output (the configuration used by the
            released checkpoint). ``""`` disables skip connections.
        num_dense_layers:
            Number of fully-connected layers between the frame-pooled representation and the unscaled MRL output.
        dense_sizes:
            Hidden sizes of the fully-connected layers. Length must match ``num_dense_layers``.
        dense_dropout:
            Dropout probability applied after every fully-connected layer.
        only_max_pool:
            If ``True``, the frame pooler concatenates only the per-frame global max pooled features (3 vectors).
            Otherwise it additionally concatenates the masked global average pooled features (6 vectors), matching
            the released checkpoint.
        library_size:
            Number of training sub-libraries supported by the scaling regression head. The released checkpoint was
            trained jointly on the ``egfp_unmod_1`` and ``random`` libraries, so ``library_size = 2``.
        library_index:
            Default training sub-library index used to construct the one-hot library indicator at inference. Matches
            the ``random`` library used for variant effect prediction upstream (Kipoi ``UTRVariantEffectModel``).
        num_labels:
            Number of scalar outputs of the model. Framepool predicts a single scalar mean ribosome load value.
        head:
            Configuration of the [`FramepoolForSequencePrediction`] regression head.

    Examples:
        >>> from multimolecule import FramepoolConfig, FramepoolModel
        >>> # Initializing a Framepool combined_residual style configuration
        >>> configuration = FramepoolConfig()
        >>> # Initializing a model (with random weights) from the configuration
        >>> model = FramepoolModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "framepool"

    def __init__(
        self,
        vocab_size: int = 5,
        null_channel_id: int | None = 4,
        num_conv_layers: int = 3,
        conv_channels: int = 128,
        kernel_size: int | list[int] = 7,
        dilations: int | list[int] = 1,
        hidden_act: str = "relu",
        padding: str = "same",
        skip_connections: str = "residual",
        num_dense_layers: int = 1,
        dense_sizes: list[int] | None = None,
        dense_dropout: float = 0.2,
        only_max_pool: bool = False,
        library_size: int = 2,
        library_index: int = 1,
        num_labels: int = 1,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        if num_labels != 1:
            raise ValueError(f"Framepool predicts one mean-ribosome-load scalar, got num_labels={num_labels}.")
        if num_conv_layers <= 0:
            raise ValueError(f"num_conv_layers must be positive, got {num_conv_layers}.")
        kernel_size_list = [kernel_size] * num_conv_layers if isinstance(kernel_size, int) else list(kernel_size)
        dilations_list = [dilations] * num_conv_layers if isinstance(dilations, int) else list(dilations)
        if len(kernel_size_list) != num_conv_layers:
            raise ValueError(
                f"kernel_size must have num_conv_layers={num_conv_layers} entries, got {len(kernel_size_list)}."
            )
        if len(dilations_list) != num_conv_layers:
            raise ValueError(
                f"dilations must have num_conv_layers={num_conv_layers} entries, got {len(dilations_list)}."
            )
        if dense_sizes is None:
            dense_sizes = [64] * num_dense_layers
        if len(dense_sizes) != num_dense_layers:
            raise ValueError(
                f"dense_sizes must have num_dense_layers={num_dense_layers} entries, got {len(dense_sizes)}."
            )
        if padding not in ("same", "causal"):
            raise ValueError(f"padding must be 'same' or 'causal', got {padding!r}.")
        if skip_connections not in ("", "residual"):
            raise ValueError(f"skip_connections must be '' or 'residual', got {skip_connections!r}.")
        if library_size <= 0:
            raise ValueError(f"library_size must be positive, got {library_size}.")
        if not 0 <= library_index < library_size:
            raise ValueError(
                f"library_index must satisfy 0 <= library_index < library_size={library_size}, got {library_index}."
            )

        if null_channel_id is not None and not 0 <= null_channel_id < vocab_size:
            raise ValueError(
                f"null_channel_id must satisfy 0 <= null_channel_id < vocab_size={vocab_size}, "
                f"got {null_channel_id}."
            )

        self.vocab_size = vocab_size
        self.null_channel_id = null_channel_id
        self.num_conv_layers = num_conv_layers
        self.conv_channels = conv_channels
        self.kernel_size = kernel_size_list
        self.dilations = dilations_list
        self.hidden_act = hidden_act
        self.padding = padding
        self.skip_connections = skip_connections
        self.num_dense_layers = num_dense_layers
        self.dense_sizes = dense_sizes
        self.dense_dropout = dense_dropout
        self.only_max_pool = only_max_pool
        self.library_size = library_size
        self.library_index = library_index
        if head is None:
            head = HeadConfig(num_labels=num_labels, problem_type="regression")
        elif not isinstance(head, HeadConfig):
            head = HeadConfig(**head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head

    @property
    def hidden_size(self) -> int:
        """Dimensionality of the frame-pooled representation consumed by the prediction head.

        Derived from ``conv_channels`` and ``only_max_pool``; read-only to prevent silent drift.
        ``only_max_pool=True`` concatenates 3 per-frame max-pooled vectors; otherwise 6 (max + mean).
        """
        num_pools = 3 if self.only_max_pool else 6
        return self.conv_channels * num_pools
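The scalar-or-list handling of kernel_size and dilations in __init__ boils down to the following helper. per_layer is a hypothetical name used only for this sketch; FramepoolConfig performs the same steps inline.

```python
from typing import List, Sequence, Union


def per_layer(value: Union[int, Sequence[int]], num_layers: int, name: str) -> List[int]:
    """Broadcast a scalar to one entry per layer, or copy and validate a list."""
    values = [value] * num_layers if isinstance(value, int) else list(value)
    if len(values) != num_layers:
        raise ValueError(f"{name} must have num_conv_layers={num_layers} entries, got {len(values)}.")
    return values
```

Because of this normalization, FramepoolConfig(kernel_size=7) with the default three convolution layers stores kernel_size as [7, 7, 7], while a list of the wrong length raises a ValueError at construction time.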

hidden_size property

Python
hidden_size: int

Dimensionality of the frame-pooled representation consumed by the prediction head.

Derived from conv_channels and only_max_pool; read-only to prevent silent drift. only_max_pool=True concatenates 3 per-frame max-pooled vectors; otherwise 6 (max + mean).
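In worked form, using the defaults of FramepoolConfig:

```latex
\mathrm{hidden\_size}
  = \mathrm{conv\_channels} \times 3 \times
    \begin{cases}
      1 & \text{if } \mathrm{only\_max\_pool} \\
      2 & \text{otherwise}
    \end{cases}
  = 128 \times 3 \times 2 = 768
```

This matches the 768-dimensional pooler_output in the FramepoolModel example; with only_max_pool=True the same formula gives 384.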

FramepoolForSequencePrediction

Bases: FramepoolPreTrainedModel

Framepool with a sequence-level prediction head.

logits always holds the mean ribosome load (MRL) prediction of the (reference) sequence. When called with both a reference and an alternative sequence, the alternative MRL is returned in alternative_logits and the log2(alternative / reference) MRL fold change in delta, matching the upstream Kipoi UTRVariantEffectModel variant effect interface.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import FramepoolConfig, FramepoolForSequencePrediction, RnaTokenizer
>>> config = FramepoolConfig()
>>> model = FramepoolForSequencePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/framepool")
>>> input = tokenizer("ACGUACGUACGU", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1.0]]))
>>> output["logits"].shape
torch.Size([1, 1])
>>> alternative = tokenizer("ACGUACGUACGA", return_tensors="pt")
>>> output = model(**input, alternative_input_ids=alternative["input_ids"])
>>> output["delta"].shape
torch.Size([1, 1])
Source code in multimolecule/models/framepool/modeling_framepool.py
Python
class FramepoolForSequencePrediction(FramepoolPreTrainedModel):
    """
    Framepool with a sequence-level prediction head.

    `logits` always holds the mean ribosome load (MRL) prediction of the (reference) sequence. When called with both a
    reference and an alternative sequence, the alternative MRL is returned in `alternative_logits` and the
    ``log2(alternative / reference)`` MRL fold change in `delta`, matching the upstream Kipoi ``UTRVariantEffectModel``
    variant effect interface.

    Examples:
        >>> import torch
        >>> from multimolecule import FramepoolConfig, FramepoolForSequencePrediction, RnaTokenizer
        >>> config = FramepoolConfig()
        >>> model = FramepoolForSequencePrediction(config)
        >>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/framepool")
        >>> input = tokenizer("ACGUACGUACGU", return_tensors="pt")
        >>> output = model(**input, labels=torch.tensor([[1.0]]))
        >>> output["logits"].shape
        torch.Size([1, 1])
        >>> alternative = tokenizer("ACGUACGUACGA", return_tensors="pt")
        >>> output = model(**input, alternative_input_ids=alternative["input_ids"])
        >>> output["delta"].shape
        torch.Size([1, 1])
    """

    def __init__(self, config: FramepoolConfig):
        super().__init__(config)
        self.model = FramepoolModel(config)
        head_config = HeadConfig(config.head) if config.head is not None else HeadConfig()
        if head_config.num_labels is None:
            head_config.num_labels = config.num_labels
        if head_config.problem_type is None:
            head_config.problem_type = "regression"
        self.head_config = head_config
        self.criterion = Criterion(head_config)
        self.prediction = FramepoolMrlHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        return ["mean_ribosome_load"]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        library_indicator: Tensor | None = None,
        alternative_input_ids: Tensor | NestedTensor | None = None,
        alternative_attention_mask: Tensor | None = None,
        alternative_inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | FramepoolForSequencePredictorOutput:
        reference = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        reference_logits = self.prediction(reference.pooler_output, library_indicator=library_indicator)

        has_alternative = alternative_input_ids is not None or alternative_inputs_embeds is not None
        if has_alternative:
            alternative = self.model(
                alternative_input_ids,
                attention_mask=alternative_attention_mask,
                inputs_embeds=alternative_inputs_embeds,
                return_dict=True,
                **kwargs,
            )
            alternative_logits = self.prediction(alternative.pooler_output, library_indicator=library_indicator)
            # ``log2(alt / ref)`` matches the Kipoi `UTRVariantEffectModel.predict_on_batch` MRL fold-change output.
            delta = torch.log2(alternative_logits / reference_logits)
            loss = self.criterion(delta, labels) if labels is not None else None
            return FramepoolForSequencePredictorOutput(
                loss=loss,
                logits=reference_logits,
                alternative_logits=alternative_logits,
                delta=delta,
            )

        loss = self.criterion(reference_logits, labels) if labels is not None else None
        return FramepoolForSequencePredictorOutput(loss=loss, logits=reference_logits)
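One consequence of the branching in forward is that labels mean different things in the two modes: the measured MRL for a reference-only call, and the measured log2 fold change when an alternative sequence is passed. The pure-Python mirror below makes that explicit; it assumes the regression Criterion behaves like mean squared error, and assemble_outputs is a hypothetical name.

```python
import math
from typing import Dict, List, Optional


def mse(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def assemble_outputs(
    reference: List[float],
    alternative: Optional[List[float]] = None,
    labels: Optional[List[float]] = None,
) -> Dict[str, object]:
    """Pure-Python mirror of the branching in forward(); MSE stands in for Criterion."""
    if alternative is not None:
        delta = [math.log2(a / r) for r, a in zip(reference, alternative)]
        # Variant mode: labels are compared against the log2 fold change, not the MRL.
        loss = mse(delta, labels) if labels is not None else None
        return {"loss": loss, "logits": reference, "alternative_logits": alternative, "delta": delta}
    # Reference-only mode: labels are compared against the MRL prediction.
    loss = mse(reference, labels) if labels is not None else None
    return {"loss": loss, "logits": reference}
```

In both modes logits is the reference MRL, so downstream code can always read it without checking which mode was used.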

FramepoolForSequencePredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of [FramepoolForSequencePrediction].

Parameters:

  • loss (torch.FloatTensor of shape (1,), optional, default: None): Regression loss, returned when labels is provided.
  • logits (torch.FloatTensor of shape (batch_size, num_labels), default: None): Mean ribosome load (MRL) prediction of the (reference) sequence.
  • alternative_logits (torch.FloatTensor of shape (batch_size, num_labels), optional, default: None): MRL prediction of the alternative sequence, returned when an alternative sequence (alternative_input_ids or alternative_inputs_embeds) is provided.
  • delta (torch.FloatTensor of shape (batch_size, num_labels), optional, default: None): log2(alternative / reference) MRL fold change, returned when an alternative sequence is provided.

Source code in multimolecule/models/framepool/modeling_framepool.py
Python
@dataclass
class FramepoolForSequencePredictorOutput(ModelOutput):
    """
    Base class for outputs of [`FramepoolForSequencePrediction`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Regression loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Mean ribosome load (MRL) prediction of the (reference) sequence.
        alternative_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            MRL prediction of the alternative sequence, returned when an alternative sequence (`alternative_input_ids` or
            `alternative_inputs_embeds`) is provided.
        delta (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            `log2(alternative / reference)` MRL fold change, returned when an alternative sequence is provided.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    alternative_logits: torch.FloatTensor | None = None
    delta: torch.FloatTensor | None = None

FramepoolModel

Bases: FramepoolPreTrainedModel

The bare Framepool model, producing a frame-aware representation from a 5’UTR sequence.

Framepool replaces the fixed-length flatten of Sample et al., 2019 with a frame-aware pooling layer that splits the convolutional feature map into the three reading frames relative to the start codon and pools each frame independently. The resulting representation is length-independent.

Examples:

Python Console Session
>>> from multimolecule import FramepoolConfig, FramepoolModel, RnaTokenizer
>>> config = FramepoolConfig()
>>> model = FramepoolModel(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/framepool")
>>> input = tokenizer("ACGUACGUACGU", return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([1, 768])
Source code in multimolecule/models/framepool/modeling_framepool.py
Python
class FramepoolModel(FramepoolPreTrainedModel):
    """
    The bare Framepool model, producing a frame-aware representation from a 5'UTR sequence.

    Framepool replaces the fixed-length flatten of [Sample et al., 2019](https://doi.org/10.1038/s41587-019-0164-5)
    with a frame-aware pooling layer that splits the convolutional feature map into the three reading frames relative
    to the start codon and pools each frame independently. The resulting representation is length-independent.

    Examples:
        >>> from multimolecule import FramepoolConfig, FramepoolModel, RnaTokenizer
        >>> config = FramepoolConfig()
        >>> model = FramepoolModel(config)
        >>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/framepool")
        >>> input = tokenizer("ACGUACGUACGU", return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([1, 768])
    """

    def __init__(self, config: FramepoolConfig):
        super().__init__(config)
        self.embeddings = FramepoolEmbedding(config)
        self.encoder = FramepoolEncoder(config)
        self.pooler = FramepoolPooler(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> FramepoolModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        # ``(batch, vocab_size, length)``; padding tokens (and tokens outside the nucleobase alphabet)
        # are encoded as all-zero columns.
        embedding_output, pad_mask = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_state = self.encoder(embedding_output, pad_mask)
        pooled_output = self.pooler(hidden_state, pad_mask)

        return FramepoolModelOutput(
            pooler_output=pooled_output,
            last_hidden_state=hidden_state.transpose(1, 2),
        )
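The argument handling at the top of forward can be read in isolation with the sketch below. PaddedBatch is a hypothetical stand-in for NestedTensor (it only carries .tensor and .mask), and resolve_inputs is an illustrative name. The point it shows is that an explicitly passed attention_mask takes precedence over the mask carried by the nested input.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class PaddedBatch:
    """Hypothetical stand-in for NestedTensor: padded values plus their mask."""

    tensor: Any
    mask: Any


def resolve_inputs(
    input_ids: Optional[Any] = None,
    attention_mask: Optional[Any] = None,
    inputs_embeds: Optional[Any] = None,
) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
    """Mirror of the argument checks at the top of FramepoolModel.forward."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    if isinstance(input_ids, PaddedBatch):
        # An explicitly passed attention_mask takes precedence over the nested mask.
        attention_mask = input_ids.mask if attention_mask is None else attention_mask
        input_ids = input_ids.tensor
    if isinstance(inputs_embeds, PaddedBatch):
        attention_mask = inputs_embeds.mask if attention_mask is None else attention_mask
        inputs_embeds = inputs_embeds.tensor
    return input_ids, attention_mask, inputs_embeds
```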

FramepoolModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the Framepool model.

Parameters:

  • pooler_output (torch.FloatTensor of shape (batch_size, hidden_size), default: None): The concatenation of per-frame max pooled feature vectors, plus the masked average pooled vectors unless only_max_pool is True, consumed by the sequence-level prediction head.
  • last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, conv_channels), default: None): The encoder feature map before the frame-aware pooling, with padded positions zeroed out.

Source code in multimolecule/models/framepool/modeling_framepool.py
Python
@dataclass
class FramepoolModelOutput(ModelOutput):
    """
    Base class for outputs of the Framepool model.

    Args:
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            The concatenation of per-frame max (and optionally average) pooled feature vectors consumed by the
            sequence-level prediction head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_channels)`):
            The encoder feature map before the frame-aware pooling, with padded positions zeroed out.
    """

    pooler_output: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None

FramepoolPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/framepool/modeling_framepool.py
Python
class FramepoolPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FramepoolConfig
    base_model_prefix = "model"
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["FramepoolEncoder"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
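The constants in _init_weights can be checked by hand. With a = sqrt(5), the leaky-ReLU gain sqrt(2 / (1 + a²)) equals sqrt(1/3), so the Kaiming-uniform weight bound sqrt(1/3) · sqrt(3 / fan_in) collapses to 1 / sqrt(fan_in), the same bound used for the bias; this is also what PyTorch's own default reset_parameters for Conv1d and Linear produces. The sketch below verifies that numerically; the 128-channel, kernel-7 layer is only an example, and the function names are illustrative.

```python
import math


def kaiming_uniform_bound(fan_in: int, a: float = math.sqrt(5)) -> float:
    """Bound of kaiming_uniform_ with the leaky_relu gain in fan_in mode."""
    gain = math.sqrt(2.0 / (1 + a**2))
    return gain * math.sqrt(3.0 / fan_in)


def bias_bound(fan_in: int) -> float:
    """Bias bound used above: 1 / sqrt(fan_in), or 0 when fan_in is 0."""
    return 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0


# Conv1d fan_in = in_channels * kernel_size; e.g. a 128-channel conv with kernel 7.
conv_fan_in = 128 * 7
```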