
Framepool

Frame-aware pooling convolutional network for predicting mean ribosome load from variable-length 5’UTR sequences.

Disclaimer

This is an UNOFFICIAL implementation of Predicting mean ribosome load for 5’UTR of any length using deep learning by Alexander Karollus et al.

The OFFICIAL repository of Framepool is at Karollus/5UTR and the published Kipoi wrapper is at kipoi/models.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing Framepool did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

Framepool is a small 1D convolutional network that predicts the mean ribosome load (MRL) of a human 5’ untranslated region from sequence alone. It extends the fixed-length network of Sample et al., 2019 with a frame-aware pooling layer that reverses the sequence to anchor reading frames at the start codon, slices the convolutional feature map into the three reading frames, and applies global max and masked global average pooling per frame. The pooled representation is length-independent and is consumed by a small dense head followed by a per-sub-library scaling regression that recalibrates the prediction across the two training libraries (egfp_unmod_1 and random). Please refer to the Training Details section for more information on the training process.
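The frame-aware pooling step can be illustrated with a minimal pure-Python sketch on a toy feature map. It assumes the feature map is a list of per-position feature vectors ordered 5'→3' (so the last position sits right before the start codon) plus a 0/1 mask. It reverses the map, assigns each position to frame `i % 3` by its distance from the start codon, and returns per-frame global max and masked mean. The function name `frame_pool`, the choice to skip masked positions for the max as well, and the order of the concatenated output are assumptions for illustration, not the MultiMolecule implementation.

```python
from typing import List, Sequence

Vector = List[float]


def frame_pool(features: Sequence[Vector], mask: Sequence[int]) -> Vector:
    """Frame-aware pooling sketch: per-frame global max plus masked global average.

    features: one feature vector per position, ordered 5' -> 3', so the last
              position sits immediately upstream of the start codon.
    mask:     1 for real nucleotides, 0 for padding or N.
    Returns [max_frame0, max_frame1, max_frame2, avg_frame0, avg_frame1, avg_frame2],
    flattened (this ordering is an assumption of the sketch).
    """
    # Reverse so index 0 is adjacent to the start codon: frames are then anchored
    # at the start codon whatever the UTR length.
    rev = list(zip(reversed(features), reversed(mask)))
    num_channels = len(features[0])
    max_parts: List[Vector] = []
    avg_parts: List[Vector] = []
    for frame in range(3):
        # Keep unmasked positions whose distance to the start codon falls in this frame.
        rows = [f for i, (f, m) in enumerate(rev) if i % 3 == frame and m]
        if rows:
            columns = list(zip(*rows))
            max_parts.append([max(c) for c in columns])
            avg_parts.append([sum(c) / len(rows) for c in columns])
        else:
            # A frame with no valid positions (very short input) contributes zeros.
            max_parts.append([0.0] * num_channels)
            avg_parts.append([0.0] * num_channels)
    return [x for part in max_parts + avg_parts for x in part]


# The output length depends only on the channel count, not on the sequence length.
print(frame_pool([[1.0], [2.0], [3.0], [4.0]], [1, 1, 1, 1]))  # [4.0, 3.0, 2.0, 2.5, 3.0, 2.0]
print(len(frame_pool([[0.5, 1.0]] * 10, [1] * 10)))  # 12
```

With 128 channels per position, the 6 pooled vectors give the 768-dimensional representation listed in the model specification.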

The released combined_residual checkpoint is recommended by the upstream authors for variant effect scoring; it is the checkpoint exposed by the official Kipoi Framepool entry.

Model Specification

| Num Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- |
| 4 | 768 | 0.28 | 0.05 | 0.02 | unlimited |
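The hidden size and parameter count can be checked by hand from the default FramepoolConfig. The sketch below is a back-of-the-envelope count under stated assumptions: the first convolution reads all `vocab_size = 5` one-hot channels, the dense head is `hidden_size -> 64 -> 1`, and the small per-library scaling regression is left out. The helper names are hypothetical.

```python
def conv1d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    # weight (out_ch x in_ch x kernel) + bias (out_ch)
    return out_ch * in_ch * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    # weight (out x in) + bias (out)
    return out_features * in_features + out_features


def framepool_param_count(vocab_size=5, num_filters=128, kernel_size=7,
                          num_conv_layers=3, dense_size=64, only_max_pool=False):
    # Encoder: the first conv maps one-hot channels to filters, the rest map filters -> filters.
    total = conv1d_params(vocab_size, num_filters, kernel_size)
    total += (num_conv_layers - 1) * conv1d_params(num_filters, num_filters, kernel_size)
    # Frame pooling yields 3 (max only) or 6 (max + average) vectors of num_filters each.
    hidden_size = num_filters * (3 if only_max_pool else 6)
    # Dense head: hidden_size -> dense_size -> 1 (assumed final projection to the unscaled MRL).
    total += linear_params(hidden_size, dense_size) + linear_params(dense_size, 1)
    return total, hidden_size


total, hidden_size = framepool_param_count()
print(hidden_size, total, round(total / 1e6, 2))  # 768 283521 0.28
```

Under these assumptions the count lands on the 0.28 M in the table, and "4 layers" corresponds to the 3 convolutions plus the single dense layer.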

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Mean Ribosome Load Prediction

You can use this model directly to predict the mean ribosome load of a 5’UTR sequence:

Python
>>> from multimolecule import DnaTokenizer, FramepoolForSequencePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/framepool")
>>> model = FramepoolForSequencePrediction.from_pretrained("multimolecule/framepool")
>>> output = model(**tokenizer("ACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTAC", return_tensors="pt"))

>>> output.keys()
odict_keys(['logits'])

Interface

  • Input length: variable; the upstream MPRA training data consists of 25-100 nt 5’UTRs, but the model accepts any length because of frame-aware pooling
  • Alphabet: DNA (A, C, G, T); N and other non-canonical tokens are encoded as all-zero columns and ignored by the masked pooling
  • Padding: zero-padding is supported via attention_mask and is excluded from pooling
  • Output: a single scalar per sequence, the predicted mean ribosome load (logits, shape (batch_size, 1))
  • Auxiliary inputs: optional library_indicator (shape (batch_size, library_size)) selecting one of the two training sub-libraries for the scaling regression. Defaults to the random library, matching the upstream Kipoi variant effect interface
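The alphabet and padding rules above can be sketched in pure Python: A, C, G and T become one-hot columns, while N and any other symbol stay all-zero, so summing each column over channels gives the mask used by the pooling. The channel order `ACGTN` is assumed from the streamline alphabet described in FramepoolConfig; the function name is hypothetical and this is not the MultiMolecule embedding code.

```python
from typing import List, Tuple

# Channel order assumed to follow the streamline alphabet (A, C, G, T, N).
CHANNELS = "ACGTN"
NULL_CHANNEL = 4  # the N channel is zeroed, mirroring null_channel_id = 4


def one_hot_with_mask(seq: str) -> Tuple[List[List[float]], List[int]]:
    seq = seq.upper().replace("U", "T")
    columns = []
    for base in seq:
        column = [0.0] * len(CHANNELS)
        idx = CHANNELS.find(base)
        # Unknown characters (idx == -1) and N both stay all-zero.
        if idx >= 0 and idx != NULL_CHANNEL:
            column[idx] = 1.0
        columns.append(column)
    # pad_mask = sum over channels: 1 for A/C/G/T, 0 for N, padding or other symbols.
    pad_mask = [int(sum(column)) for column in columns]
    return columns, pad_mask


one_hot, mask = one_hot_with_mask("acgNtR")
print(mask)  # [1, 1, 1, 0, 1, 0]
```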

Variant Effect

Framepool supports paired reference/alternative scoring through the optional alternative_input_ids argument:

  • Single sequence (reference only): logits is the predicted mean ribosome load (one scalar per sequence)
  • Reference + alternative: logits is the log2 mean ribosome load fold change log2(MRL_alt / MRL_ref), matching the Kipoi UTRVariantEffectModel.predict_on_batch mrl_fold_change output
  • Reference and alternative sequences are scored independently; both must use the same library_indicator so that the scaling regression cancels out of the fold change
  • For the upstream “shifted-frame” variant effect outputs (shift_1, shift_2), prepend one or two zero columns (or N tokens) to both reference and alternative inputs before scoring, matching the upstream Kipoi implementation
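The fold-change arithmetic and the shifted-frame variants can be sketched independently of the model. In the sketch below, `predict_mrl` stands for any callable that returns a positive MRL for a sequence (for example, a wrapper around the model with a fixed library_indicator). The key names mirror the outputs mentioned above, but the dictionary layout and the toy predictor are hypothetical.

```python
import math
from typing import Callable, Dict


def variant_effects(predict_mrl: Callable[[str], float], ref: str, alt: str) -> Dict[str, float]:
    """log2(MRL_alt / MRL_ref) for the unshifted input and with 1 or 2 leading N tokens."""
    effects: Dict[str, float] = {}
    for shift in range(3):
        prefix = "N" * shift  # N is encoded as an all-zero column, like padding
        mrl_ref = predict_mrl(prefix + ref)
        mrl_alt = predict_mrl(prefix + alt)
        key = "mrl_fold_change" if shift == 0 else f"shift_{shift}"
        effects[key] = math.log2(mrl_alt / mrl_ref)
    return effects


# Hypothetical stand-in for the model: any callable returning a positive MRL works here.
def toy_mrl(seq: str) -> float:
    return 1.0 + seq.count("G")


print(variant_effects(toy_mrl, "ACGT", "AGGT"))
```

A positive value means the alternative allele is predicted to increase ribosome load; identical alleles give exactly 0.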

Training Details

Framepool was trained on polysome-profiling MPRA data measuring the mean ribosome load of randomized 5’UTR sequences and uses frame-aware pooling so that a single network can score sequences of arbitrary length.

Training Data

Framepool was trained on the eGFP polysome-profiling MPRA libraries of Sample et al., 2019 in HEK293T cells: the fixed-length library (egfp_unmod_1, 50 nt) and the variable-length library (random, 25-100 nt). Approximately 260,000 sequences were used for training, with 20,000 held out for testing; additional validation was performed on endogenous data.

Training Procedure

Pre-training

  • Loss: mean squared error between the predicted and measured mean ribosome load
  • Optimizer: Adam with lr = 1e-3, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-8
  • Epochs: 6
  • Mini-batch sampling: the two training libraries are mixed within every batch; a one-hot library indicator is fed to the scaling regression layer so that the network can absorb the library-specific offset
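The batch-mixing scheme can be sketched as follows: every mini-batch draws from both libraries and tags each example with a one-hot library indicator, and the loss is the plain mean squared error on MRL. The round-robin draw, the seeding, and the helper names are assumptions; the text above only states that the two libraries are mixed within every batch.

```python
import random
from typing import List, Sequence, Tuple

LIBRARIES = ("egfp_unmod_1", "random")  # indices 0 and 1, matching library_index=1 for "random"
Example = Tuple[str, float]  # (5'UTR sequence, measured MRL)


def mixed_batch(pools: Sequence[Sequence[Example]], batch_size: int, seed: int = 0):
    """Draw one mini-batch containing every library, each row tagged with a one-hot library indicator."""
    rng = random.Random(seed)
    batch: List[Tuple[str, List[float], float]] = []
    for i in range(batch_size):
        lib = i % len(pools)  # round-robin: every batch mixes all libraries
        seq, mrl = rng.choice(pools[lib])
        indicator = [1.0 if j == lib else 0.0 for j in range(len(pools))]
        batch.append((seq, indicator, mrl))
    rng.shuffle(batch)
    return batch


def mse(predictions: Sequence[float], targets: Sequence[float]) -> float:
    # Mean squared error between predicted and measured MRL.
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


pools = [[("ACGTACGTAC", 4.1)], [("GGCATGCA", 6.3), ("TTTACG", 2.2)]]
for seq, indicator, mrl in mixed_batch(pools, batch_size=4):
    print(LIBRARIES[indicator.index(1.0)], seq, mrl)
```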

Citation

BibTeX
@article{karollus2021predicting,
  author    = {Karollus, Alexander and Avsec, {\v Z}iga and Gagneur, Julien},
  title     = {Predicting mean ribosome load for 5{\textquoteright}UTR of any length using deep learning},
  journal   = {PLOS Computational Biology},
  volume    = {17},
  number    = {5},
  pages     = {e1008982},
  year      = {2021},
  publisher = {Public Library of Science},
  doi       = {10.1371/journal.pcbi.1008982}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the Framepool paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.framepool

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| alphabet | Alphabet \| str \| List[str] \| None | Alphabet to use for tokenization. If None, the standard DNA alphabet is used. If a string, it must name a predefined alphabet: standard, iupac, streamline, or nucleobase. If an alphabet or a list of characters, that specific alphabet is used. | None |
| nmers | int | Size of k-mer to tokenize. | 1 |
| codon | bool | Whether to tokenize into codons. | False |
| replace_U_with_T | bool | Whether to replace U with T. | True |
| do_upper_case | bool | Whether to convert input to uppercase. | True |

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
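The k-mer ids in the examples above follow a simple pattern that a short sketch can reproduce. It assumes six special tokens (`<pad>`, `<cls>`, `<eos>`, `<unk>`, `<mask>`, `<null>`) take ids 0-5, and that k-mers over the streamline alphabet A, C, G, T, N follow in lexicographic order. This is an illustrative re-derivation consistent with the documented outputs, not the tokenizer's actual vocabulary code; `kmer_vocab` and `encode` are hypothetical names.

```python
from itertools import product
from typing import Dict, List

# Assumption: ids 0-5 are special tokens, then k-mers over A, C, G, T, N in lexicographic order.
SPECIAL_TOKENS = 6
STREAMLINE = "ACGTN"


def kmer_vocab(k: int) -> Dict[str, int]:
    return {"".join(p): SPECIAL_TOKENS + i for i, p in enumerate(product(STREAMLINE, repeat=k))}


def encode(seq: str, k: int = 3, codon: bool = False) -> List[int]:
    vocab = kmer_vocab(k)
    seq = seq.upper().replace("U", "T")
    if codon:
        # Non-overlapping triplets, like DnaTokenizer(codon=True).
        if len(seq) % k != 0:
            raise ValueError(f"length must be a multiple of {k}, got {len(seq)}")
        tokens = [seq[i : i + k] for i in range(0, len(seq), k)]
    else:
        # Overlapping k-mers with stride 1, like DnaTokenizer(nmers=k).
        tokens = [seq[i : i + k] for i in range(len(seq) - k + 1)]
    # 1 = <cls>, 2 = <eos>, as in the examples above
    return [1] + [vocab[t] for t in tokens] + [2]


print(encode("tataaagta"))              # [1, 84, 21, 81, 6, 8, 19, 71, 2]
print(encode("tataaagta", codon=True))  # [1, 84, 6, 71, 2]
```

For example, TAT maps to 6 + 3·25 + 0·5 + 3 = 84, treating each k-mer as a base-5 number over ACGTN.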

FramepoolConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a FramepoolModel. It is used to instantiate a Framepool model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Framepool combined_residual architecture released with the Karollus et al., 2021 paper.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Number of one-hot input channels derived from the MultiMolecule DNA tokenizer. Defaults to 5 (A, C, G, T, N), matching the MultiMolecule DNA streamline alphabet. The upstream checkpoint only learns the first four ACGT channels; the N channel stays zero, matching the upstream compute_pad_mask semantics. | 5 |
| null_channel_id | int \| None | Channel index that represents the upstream “no nucleobase” token (N and padding). The embedding zeroes this column so that the upstream pad_mask = sum(one_hot, axis=2) mask correctly identifies padded positions. Set to None to keep all channels. | 4 |
| num_conv_layers | int | Number of stacked length-preserving residual convolutions in the encoder. | 3 |
| num_filters | int | Number of output channels for every convolution in the encoder. | 128 |
| kernel_size | int \| list[int] | Kernel sizes of the encoder convolutions: either a scalar shared across all layers, or a list with one entry per layer. | 7 |
| dilations | int \| list[int] | Dilation rates of the encoder convolutions: either a scalar shared across all layers, or a list with one entry per layer. | 1 |
| hidden_act | str | Non-linear activation applied after each encoder convolution. | 'relu' |
| padding | str | Convolution padding mode. same keeps the sequence length; causal left-pads to retain it. | 'same' |
| skip_connections | str | residual adds the input of every convolution past the first to its output (the configuration used by the released checkpoint); "" disables skip connections. | 'residual' |
| num_dense_layers | int | Number of fully-connected layers between the frame-pooled representation and the unscaled MRL output. | 1 |
| dense_sizes | list[int] \| None | Hidden sizes of the fully-connected layers. Length must match num_dense_layers. If None, every layer has 64 units. | None |
| dense_dropout | float | Dropout probability applied after every fully-connected layer. | 0.2 |
| only_max_pool | bool | If True, the frame pooler concatenates only the per-frame global max pooled features (3 vectors). Otherwise it additionally concatenates the masked global average pooled features (6 vectors in total), matching the released checkpoint. | False |
| library_size | int | Number of training sub-libraries supported by the scaling regression head. The released checkpoint was trained jointly on the egfp_unmod_1 and random libraries, so library_size = 2. | 2 |
| library_index | int | Default training sub-library index used to construct the one-hot library indicator at inference. Matches the random library used for variant effect prediction upstream (Kipoi UTRVariantEffectModel). | 1 |
| num_labels | int | Number of scalar outputs of the model. Framepool predicts a single scalar mean ribosome load value, so any other value raises a ValueError. | 1 |
| head | HeadConfig \| None | Configuration of the FramepoolForSequencePrediction regression head. | None |

Examples:

Python Console Session
>>> from multimolecule import FramepoolConfig, FramepoolModel
>>> # Initializing a Framepool combined_residual style configuration
>>> configuration = FramepoolConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = FramepoolModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/framepool/configuration_framepool.py
Python
class FramepoolConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`FramepoolModel`][multimolecule.models.FramepoolModel]. It is used to instantiate a Framepool model according to
    the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the Framepool ``combined_residual`` architecture released with the
    [Karollus et al., 2021](https://doi.org/10.1371/journal.pcbi.1008982) paper.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Number of one-hot input channels derived from the MultiMolecule DNA tokenizer. Defaults to 5
            (``A``, ``C``, ``G``, ``T``, ``N``), matching the MultiMolecule DNA ``streamline`` alphabet. The
            upstream checkpoint only learns the first four ``ACGT`` channels; the ``N`` channel stays zero,
            matching the upstream ``compute_pad_mask`` semantics.
        null_channel_id:
            Channel index that represents the upstream "no nucleobase" token (``N`` and padding). The embedding
            zeroes this column so that the upstream ``pad_mask = sum(one_hot, axis=2)`` mask correctly identifies
            padded positions. Set to ``None`` to keep all channels.
        num_conv_layers:
            Number of stacked length-preserving residual convolutions in the encoder.
        num_filters:
            Number of output channels for every convolution in the encoder.
        kernel_size:
            Kernel sizes of the encoder convolutions. Either a scalar shared across all layers, or a list with one
            entry per layer.
        dilations:
            Dilation rates of the encoder convolutions. Either a scalar shared across all layers, or a list with one
            entry per layer.
        hidden_act:
            Non-linear activation applied after each encoder convolution.
        padding:
            Convolution padding mode. ``same`` keeps the sequence length; ``causal`` left-pads to retain it.
        skip_connections:
            ``residual`` adds the input of every conv past the first to its output (the configuration used by the
            released checkpoint). ``""`` disables skip connections.
        num_dense_layers:
            Number of fully-connected layers between the frame-pooled representation and the unscaled MRL output.
        dense_sizes:
            Hidden sizes of the fully-connected layers. Length must match ``num_dense_layers``.
        dense_dropout:
            Dropout probability applied after every fully-connected layer.
        only_max_pool:
            If ``True``, the frame pooler concatenates only the per-frame global max pooled features (3 vectors).
            Otherwise it additionally concatenates the masked global average pooled features (6 vectors), matching
            the released checkpoint.
        library_size:
            Number of training sub-libraries supported by the scaling regression head. The released checkpoint was
            trained jointly on the ``egfp_unmod_1`` and ``random`` libraries, so ``library_size = 2``.
        library_index:
            Default training sub-library index used to construct the one-hot library indicator at inference. Matches
            the ``random`` library used for variant effect prediction upstream (Kipoi ``UTRVariantEffectModel``).
        num_labels:
            Number of scalar outputs of the model. Framepool predicts a single scalar mean ribosome load value.
        head:
            Configuration of the [`FramepoolForSequencePrediction`] regression head.

    Examples:
        >>> from multimolecule import FramepoolConfig, FramepoolModel
        >>> # Initializing a Framepool combined_residual style configuration
        >>> configuration = FramepoolConfig()
        >>> # Initializing a model (with random weights) from the configuration
        >>> model = FramepoolModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "framepool"

    def __init__(
        self,
        vocab_size: int = 5,
        null_channel_id: int | None = 4,
        num_conv_layers: int = 3,
        num_filters: int = 128,
        kernel_size: int | list[int] = 7,
        dilations: int | list[int] = 1,
        hidden_act: str = "relu",
        padding: str = "same",
        skip_connections: str = "residual",
        num_dense_layers: int = 1,
        dense_sizes: list[int] | None = None,
        dense_dropout: float = 0.2,
        only_max_pool: bool = False,
        library_size: int = 2,
        library_index: int = 1,
        num_labels: int = 1,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        if num_labels != 1:
            raise ValueError(f"Framepool predicts one mean-ribosome-load scalar, got num_labels={num_labels}.")
        if num_conv_layers <= 0:
            raise ValueError(f"num_conv_layers must be positive, got {num_conv_layers}.")
        kernel_size_list = [kernel_size] * num_conv_layers if isinstance(kernel_size, int) else list(kernel_size)
        dilations_list = [dilations] * num_conv_layers if isinstance(dilations, int) else list(dilations)
        if len(kernel_size_list) != num_conv_layers:
            raise ValueError(
                f"kernel_size must have num_conv_layers={num_conv_layers} entries, got {len(kernel_size_list)}."
            )
        if len(dilations_list) != num_conv_layers:
            raise ValueError(
                f"dilations must have num_conv_layers={num_conv_layers} entries, got {len(dilations_list)}."
            )
        if dense_sizes is None:
            dense_sizes = [64] * num_dense_layers
        if len(dense_sizes) != num_dense_layers:
            raise ValueError(
                f"dense_sizes must have num_dense_layers={num_dense_layers} entries, got {len(dense_sizes)}."
            )
        if padding not in ("same", "causal"):
            raise ValueError(f"padding must be 'same' or 'causal', got {padding!r}.")
        if skip_connections not in ("", "residual"):
            raise ValueError(f"skip_connections must be '' or 'residual', got {skip_connections!r}.")
        if library_size <= 0:
            raise ValueError(f"library_size must be positive, got {library_size}.")
        if not 0 <= library_index < library_size:
            raise ValueError(
                f"library_index must satisfy 0 <= library_index < library_size={library_size}, got {library_index}."
            )

        if null_channel_id is not None and not 0 <= null_channel_id < vocab_size:
            raise ValueError(
                f"null_channel_id must satisfy 0 <= null_channel_id < vocab_size={vocab_size}, "
                f"got {null_channel_id}."
            )

        self.vocab_size = vocab_size
        self.null_channel_id = null_channel_id
        self.num_conv_layers = num_conv_layers
        self.num_filters = num_filters
        self.kernel_size = kernel_size_list
        self.dilations = dilations_list
        self.hidden_act = hidden_act
        self.padding = padding
        self.skip_connections = skip_connections
        self.num_dense_layers = num_dense_layers
        self.dense_sizes = dense_sizes
        self.dense_dropout = dense_dropout
        self.only_max_pool = only_max_pool
        self.library_size = library_size
        self.library_index = library_index
        # The pooled representation concatenates per-frame max (and optionally average) pooled feature vectors.
        num_pools = 3 if only_max_pool else 6
        self.hidden_size = num_filters * num_pools
        if head is None:
            head = HeadConfig(num_labels=num_labels, problem_type="regression")
        elif not isinstance(head, HeadConfig):
            head = HeadConfig(**head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
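The padding and skip_connections options can be illustrated with a single-channel pure-Python sketch. "same" splits the `dilation * (kernel_size - 1)` zero padding between both ends, while "causal" puts all of it on the left; either way a stride-1 convolution keeps the sequence length, which is what lets residual additions line up. The helper names are hypothetical, and applying the skip after the activation is an assumption of this sketch rather than a statement about the MultiMolecule encoder.

```python
from typing import List, Tuple


def pad_amounts(kernel_size: int, dilation: int, padding: str) -> Tuple[int, int]:
    """(left, right) zero padding that keeps a stride-1 dilated convolution length-preserving."""
    total = dilation * (kernel_size - 1)
    if padding == "same":
        return total // 2, total - total // 2
    if padding == "causal":
        return total, 0  # everything on the left, so outputs never see later positions
    raise ValueError(f"padding must be 'same' or 'causal', got {padding!r}")


def conv1d(x: List[float], weight: List[float], dilation: int = 1, padding: str = "same") -> List[float]:
    """Single-channel dilated 1D cross-correlation (what deep-learning 'convolutions' compute)."""
    left, right = pad_amounts(len(weight), dilation, padding)
    xp = [0.0] * left + list(x) + [0.0] * right
    return [sum(w * xp[i + j * dilation] for j, w in enumerate(weight)) for i in range(len(x))]


def residual_conv(x: List[float], weight: List[float], dilation: int = 1, padding: str = "same") -> List[float]:
    """ReLU(conv(x)) + x; applying the skip after the activation is an assumption of this sketch."""
    y = [max(0.0, v) for v in conv1d(x, weight, dilation, padding)]
    return [a + b for a, b in zip(y, x)]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
print(len(conv1d(x, [1.0] * 7, dilation=3)))  # 5: the length is preserved
```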

FramepoolForSequencePrediction

Bases: FramepoolPreTrainedModel

Framepool with a sequence-level prediction head.

When called with a single sequence, the head returns the predicted mean ribosome load (MRL). When called with both a reference and an alternative sequence, it returns the log2 mean ribosome load fold change (log2(alternative / reference)), matching the upstream Kipoi UTRVariantEffectModel variant effect interface.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import FramepoolConfig, FramepoolForSequencePrediction, DnaTokenizer
>>> config = FramepoolConfig()
>>> model = FramepoolForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/framepool")
>>> input = tokenizer("ACGTACGTACGT", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1.0]]))
>>> output["logits"].shape
torch.Size([1, 1])
>>> alternative = tokenizer("ACGTACGTACGA", return_tensors="pt")
>>> output = model(**input, alternative_input_ids=alternative["input_ids"])
>>> output["logits"].shape
torch.Size([1, 1])
Source code in multimolecule/models/framepool/modeling_framepool.py
Python
class FramepoolForSequencePrediction(FramepoolPreTrainedModel):
    """
    Framepool with a sequence-level prediction head.

    When called with a single sequence, the head returns the predicted mean ribosome load (MRL). When called with
    both a reference and an alternative sequence, it returns the ``log2`` mean ribosome load fold change
    (``log2(alternative / reference)``), matching the upstream Kipoi
    ``UTRVariantEffectModel`` variant effect interface.

    Examples:
        >>> import torch
        >>> from multimolecule import FramepoolConfig, FramepoolForSequencePrediction, DnaTokenizer
        >>> config = FramepoolConfig()
        >>> model = FramepoolForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/framepool")
        >>> input = tokenizer("ACGTACGTACGT", return_tensors="pt")
        >>> output = model(**input, labels=torch.tensor([[1.0]]))
        >>> output["logits"].shape
        torch.Size([1, 1])
        >>> alternative = tokenizer("ACGTACGTACGA", return_tensors="pt")
        >>> output = model(**input, alternative_input_ids=alternative["input_ids"])
        >>> output["logits"].shape
        torch.Size([1, 1])
    """

    def __init__(self, config: FramepoolConfig):
        super().__init__(config)
        self.model = FramepoolModel(config)
        head_config = HeadConfig(config.head) if config.head is not None else HeadConfig()
        if head_config.num_labels is None:
            head_config.num_labels = config.num_labels
        if head_config.problem_type is None:
            head_config.problem_type = "regression"
        self.head_config = head_config
        self.criterion = Criterion(head_config)
        self.prediction = FramepoolMrlHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        return ["mean_ribosome_load"]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        library_indicator: Tensor | None = None,
        alternative_input_ids: Tensor | NestedTensor | None = None,
        alternative_attention_mask: Tensor | None = None,
        alternative_inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SequencePredictorOutput:
        reference = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        reference_logits = self.prediction(reference.pooler_output, library_indicator=library_indicator)

        has_alternative = alternative_input_ids is not None or alternative_inputs_embeds is not None
        if has_alternative:
            alternative = self.model(
                alternative_input_ids,
                attention_mask=alternative_attention_mask,
                inputs_embeds=alternative_inputs_embeds,
                return_dict=True,
                **kwargs,
            )
            alternative_logits = self.prediction(alternative.pooler_output, library_indicator=library_indicator)
            # ``log2(alt / ref)`` matches the Kipoi `UTRVariantEffectModel.predict_on_batch` MRL fold-change output.
            logits = torch.log2(alternative_logits / reference_logits)
            loss = self.criterion(logits, labels) if labels is not None else None
            return SequencePredictorOutput(loss=loss, logits=logits)

        logits = reference_logits
        loss = self.criterion(logits, labels) if labels is not None else None
        return SequencePredictorOutput(loss=loss, logits=logits)

FramepoolModel

Bases: FramepoolPreTrainedModel

The bare Framepool model, producing a frame-aware representation from a 5’UTR sequence.

Framepool replaces the fixed-length flatten of Sample et al., 2019 with a frame-aware pooling layer that splits the convolutional feature map into the three reading frames relative to the start codon and pools each frame independently. The resulting representation is length-independent.

Examples:

Python Console Session
>>> from multimolecule import FramepoolConfig, FramepoolModel, DnaTokenizer
>>> config = FramepoolConfig()
>>> model = FramepoolModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/framepool")
>>> input = tokenizer("ACGTACGTACGT", return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([1, 768])
Source code in multimolecule/models/framepool/modeling_framepool.py
Python
class FramepoolModel(FramepoolPreTrainedModel):
    """
    The bare Framepool model, producing a frame-aware representation from a 5'UTR sequence.

    Framepool replaces the fixed-length flatten of [Sample et al., 2019](https://doi.org/10.1038/s41587-019-0164-5)
    with a frame-aware pooling layer that splits the convolutional feature map into the three reading frames relative
    to the start codon and pools each frame independently. The resulting representation is length-independent.

    Examples:
        >>> from multimolecule import FramepoolConfig, FramepoolModel, DnaTokenizer
        >>> config = FramepoolConfig()
        >>> model = FramepoolModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/framepool")
        >>> input = tokenizer("ACGTACGTACGT", return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([1, 768])
    """

    def __init__(self, config: FramepoolConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = FramepoolEmbedding(config)
        self.encoder = FramepoolEncoder(config)
        self.pooler = FramepoolPooler(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> FramepoolModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        # ``(batch, vocab_size, length)``; padding tokens (and tokens outside the nucleobase alphabet)
        # are encoded as all-zero columns.
        embedding_output, pad_mask = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_state = self.encoder(embedding_output, pad_mask)
        pooled_output = self.pooler(hidden_state, pad_mask)

        return FramepoolModelOutput(
            pooler_output=pooled_output,
            last_hidden_state=hidden_state.transpose(1, 2),
        )
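The tensor shapes through FramepoolModel.forward can be summarised with a small bookkeeping sketch. It follows the comments and the `transpose(1, 2)` in the code above: channels-first embeddings, a length-preserving encoder, and a pooled output whose width depends only on num_filters. The shape of pad_mask is an assumption, and `framepool_shapes` is a hypothetical helper.

```python
from typing import Dict, Tuple


def framepool_shapes(batch_size: int, length: int, vocab_size: int = 5,
                     num_filters: int = 128, only_max_pool: bool = False) -> Dict[str, Tuple[int, ...]]:
    """Tensor shapes through FramepoolModel.forward (bookkeeping sketch)."""
    num_pools = 3 if only_max_pool else 6
    return {
        "embedding_output": (batch_size, vocab_size, length),    # channels-first one-hot
        "pad_mask": (batch_size, length),                        # assumed shape
        "encoder_output": (batch_size, num_filters, length),     # length-preserving convolutions
        "last_hidden_state": (batch_size, length, num_filters),  # encoder output after transpose(1, 2)
        "pooler_output": (batch_size, num_filters * num_pools),  # independent of length
    }


print(framepool_shapes(1, 12)["pooler_output"])  # (1, 768)
```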

FramepoolModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the Framepool model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pooler_output | `torch.FloatTensor` of shape `(batch_size, hidden_size)` | The concatenation of per-frame max (and optionally average) pooled feature vectors consumed by the sequence-level prediction head. | None |
| last_hidden_state | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_filters)` | The encoder feature map before the frame-aware pooling, with padded positions zeroed out. | None |
Source code in multimolecule/models/framepool/modeling_framepool.py
Python
@dataclass
class FramepoolModelOutput(ModelOutput):
    """
    Base class for outputs of the Framepool model.

    Args:
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            The concatenation of per-frame max (and optionally average) pooled feature vectors consumed by the
            sequence-level prediction head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_filters)`):
            The encoder feature map before the frame-aware pooling, with padded positions zeroed out.
    """

    pooler_output: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None

FramepoolPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/framepool/modeling_framepool.py
Python
class FramepoolPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FramepoolConfig
    base_model_prefix = "model"
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["FramepoolEncoder"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
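The bounds used by `_init_weights` can be worked out by hand. With `a = sqrt(5)`, the leaky-ReLU gain in `kaiming_uniform_` is `sqrt(2 / (1 + 5)) = sqrt(1/3)`, so the weight bound `gain * sqrt(3 / fan_in)` reduces to `1 / sqrt(fan_in)`, the same bound used for the bias. To my understanding this matches PyTorch's own default initialisation for `nn.Conv1d` and `nn.Linear`. The fan-in values below assume the default config (5 input channels, kernel 7, 128 filters, 768 pooled features).

```python
import math


def kaiming_uniform_bound(fan_in: int, a: float = math.sqrt(5)) -> float:
    # Leaky-ReLU gain used by kaiming_uniform_: sqrt(2 / (1 + a^2)).
    gain = math.sqrt(2.0 / (1 + a * a))
    # Weights are drawn from Uniform(-b, b) with b = gain * sqrt(3 / fan_in).
    return gain * math.sqrt(3.0 / fan_in)


def bias_bound(fan_in: int) -> float:
    # Mirrors the bias rule in _init_weights above.
    return 1 / math.sqrt(fan_in) if fan_in > 0 else 0


# Conv1d fan_in = in_channels * kernel_size; Linear fan_in = in_features.
for name, fan_in in [("first conv", 5 * 7), ("inner conv", 128 * 7), ("dense", 768)]:
    print(name, round(kaiming_uniform_bound(fan_in), 4), round(bias_bound(fan_in), 4))
```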