Enformer

Transformer-based deep neural network for predicting genomic coverage tracks from long DNA sequences with long-range context.

Disclaimer

This is an UNOFFICIAL implementation of “Effective gene expression prediction from sequence by integrating long-range interactions” by Žiga Avsec, Vikram Agarwal, Daniel Visentin et al.

The OFFICIAL repository of Enformer is at google-deepmind/deepmind-research/enformer.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released Enformer did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

Enformer is the successor of Basenji. It replaces Basenji’s dilated convolution tower with a convolution stem followed by a Transformer trunk, which lets it model long-range genomic interactions. It consumes a long DNA window (~393 kb), passes it through a convolution + attention-pooling stem that downsamples the sequence by 2 ** 7 = 128x, processes the binned representation with 11 Transformer blocks using Transformer-XL style relative positional encoding, center-crops to 896 output bins, and applies a pointwise head plus a per-species linear track projection with a softplus activation. The prediction is binned: the output has shape (batch_size, target_length, num_tracks) where each bin summarizes 128 bp of sequence and num_tracks is the number of genomic coverage experiments for the selected species.
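The bin arithmetic in this paragraph can be checked with a few lines of plain Python. This is a minimal sketch that assumes the crop drops the same number of bins from each end; `enformer_geometry` is a hypothetical helper, not part of the multimolecule API.

```python
def enformer_geometry(
    sequence_length: int = 393_216,
    num_downsamples: int = 7,
    target_length: int = 896,
) -> dict:
    """Hypothetical helper: derive Enformer's binned geometry from three size parameters."""
    pool_factor = 2**num_downsamples  # each 2x downsampling step halves the length
    num_bins = sequence_length // pool_factor  # bins produced by the convolution stem
    crop_start = (num_bins - target_length) // 2  # assumes an equal crop on both ends
    return {
        "pool_factor": pool_factor,  # bp summarized by one bin
        "num_bins": num_bins,
        "crop_start": crop_start,
        "crop_end": crop_start + target_length,
        "output_span_bp": target_length * pool_factor,  # bp covered by the kept bins
    }


print(enformer_geometry())
# {'pool_factor': 128, 'num_bins': 3072, 'crop_start': 1088, 'crop_end': 1984, 'output_span_bp': 114688}
```

With the defaults, the stem yields 3,072 bins and only the central 896 (114,688 bp) receive predictions; under this reading, the flanking bins still feed the Transformer trunk as context before the crop.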

Model Specification

| Input Length (bp) | Bin Size (bp) | Output Bins | Hidden Size | Layers | Heads | Num Labels | Num Parameters (M) |
|---|---|---|---|---|---|---|---|
| 393216 | 128 | 896 | 1536 | 11 | 8 | 5313 | 246.2 |

The table reports the default human output head; the mouse head predicts 1643 tracks.

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Genomic Coverage Prediction

You can use this model to predict binned genomic coverage tracks from a DNA sequence:

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, EnformerConfig, EnformerForTokenPrediction

>>> config = EnformerConfig(
...     sequence_length=256, hidden_size=12, num_hidden_layers=1, num_attention_heads=2,
...     attention_head_size=4, num_downsamples=3, dim_divisible_by=2, target_length=16,
...     num_labels=4,
... )
>>> model = EnformerForTokenPrediction(config)
>>> output = model(torch.randint(config.vocab_size, (1, 256)))
>>> output.logits.shape
torch.Size([1, 16, 4])

The binned positional axis is treated as the “token” axis: each output position corresponds to one genomic bin rather than a single nucleotide. The species configuration option selects the human (5,313 tracks) or mouse (1,643 tracks) output head.
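Because each output position is a genomic bin, it can help to map a bin index back to base-pair coordinates inside the input window. A minimal sketch, assuming a symmetric center crop and 0-based half-open intervals; `bin_to_interval` is a hypothetical helper, not a multimolecule function.

```python
def bin_to_interval(
    bin_index: int,
    sequence_length: int = 393_216,
    bin_size: int = 128,
    target_length: int = 896,
) -> tuple[int, int]:
    """Hypothetical helper: map an output bin to a [start, end) bp interval of the input window."""
    if not 0 <= bin_index < target_length:
        raise IndexError(f"bin_index must be in [0, {target_length}), got {bin_index}")
    num_bins = sequence_length // bin_size  # bins before cropping
    crop_start = (num_bins - target_length) // 2  # assumes an equal crop on both ends
    start = (crop_start + bin_index) * bin_size
    return start, start + bin_size


print(bin_to_interval(0))    # (139264, 139392), first predicted bin
print(bin_to_interval(895))  # (253824, 253952), last predicted bin
```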

Interface

  • Input length: fixed 393,216 bp DNA window
  • Output binning: 128 bp per output bin; 896 output bins per window (after center-cropping the binned representation)
  • Species head: select human (5,313 tracks) or mouse (1,643 tracks) via the species config option
  • Output: (batch_size, target_length, num_tracks)
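The interface above can be summarized as a small shape check. This sketch only restates the numbers in this card (fixed window, 896 bins, 5313 human or 1643 mouse tracks); `expected_output_shape` and `SPECIES_NUM_TRACKS` are hypothetical names used for illustration.

```python
SPECIES_NUM_TRACKS = {"human": 5313, "mouse": 1643}  # track counts stated in this model card


def expected_output_shape(
    batch_size: int,
    input_length: int,
    species: str = "human",
    sequence_length: int = 393_216,
    target_length: int = 896,
) -> tuple[int, int, int]:
    """Hypothetical helper: validate the input window and return (batch_size, target_length, num_tracks)."""
    if input_length != sequence_length:
        raise ValueError(f"expected a {sequence_length} bp window, got {input_length} bp")
    if species not in SPECIES_NUM_TRACKS:
        raise ValueError(f"species must be one of {sorted(SPECIES_NUM_TRACKS)}, got {species!r}")
    return batch_size, target_length, SPECIES_NUM_TRACKS[species]


print(expected_output_shape(2, 393_216))           # (2, 896, 5313)
print(expected_output_shape(2, 393_216, "mouse"))  # (2, 896, 1643)
```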

Training Details

Enformer was trained to predict genomic coverage tracks (DNase-seq, ATAC-seq, ChIP-seq and CAGE) from the human and mouse reference genomes.

Training Data

The model was trained on a large compendium of functional genomics experiments aligned to the human (hg38) and mouse (mm10) reference genomes. The genome was divided into overlapping windows; for each window the per-128-bp coverage of every experiment served as the regression target.
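The model card does not say how per-base signal was reduced to one value per 128 bp bin. The sketch below assumes a plain sum (or mean) per bin and ignores any clipping or normalization the original data pipeline may have applied; `bin_coverage` is a hypothetical helper.

```python
def bin_coverage(per_base: list[float], bin_size: int = 128, reduce: str = "sum") -> list[float]:
    """Hypothetical helper: collapse per-base coverage into one value per fixed-width bin."""
    if reduce not in ("sum", "mean"):
        raise ValueError(f"reduce must be 'sum' or 'mean', got {reduce!r}")
    if len(per_base) % bin_size != 0:
        raise ValueError(f"coverage length {len(per_base)} is not a multiple of bin_size {bin_size}")
    binned = []
    for start in range(0, len(per_base), bin_size):
        total = sum(per_base[start : start + bin_size])  # aggregate one bin-sized window
        binned.append(total if reduce == "sum" else total / bin_size)
    return binned


coverage = [1.0] * 128 + [0.5] * 128  # hypothetical per-base signal for a 256 bp stretch
print(bin_coverage(coverage))                 # [128.0, 64.0]
print(bin_coverage(coverage, reduce="mean"))  # [1.0, 0.5]
```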

Training Procedure

Pre-training

The model was trained to minimize a Poisson regression loss between predicted and observed coverage, using a softplus output activation to keep the predicted coverage non-negative.
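The objective can be illustrated in plain Python: softplus maps raw outputs to non-negative rates, and the Poisson negative log-likelihood compares them with observed coverage. This is a sketch of the standard formulation (the constant log(observed!) term is dropped and a small `eps` guards log(0)); it is not taken from the Enformer training code, and the input values are made up.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x)); never negative."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def poisson_loss(predicted: list[float], observed: list[float], eps: float = 1e-7) -> float:
    """Mean Poisson negative log-likelihood without the constant log(observed!) term."""
    terms = [p - o * math.log(p + eps) for p, o in zip(predicted, observed)]
    return sum(terms) / len(terms)


raw = [-3.0, 0.0, 2.5]  # hypothetical pre-activation outputs for three tracks
predicted = [softplus(x) for x in raw]
observed = [0.0, 1.0, 12.0]  # hypothetical observed coverage
print(predicted, poisson_loss(predicted, observed))
```

For a single observation the loss is smallest when the predicted rate equals the observed count, which is what makes it a natural fit for count-like coverage.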

Citation

BibTeX
@article{avsec2021effective,
  author    = {Avsec, {\v{Z}}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
  title     = {Effective gene expression prediction from sequence by integrating long-range interactions},
  journal   = {Nature Methods},
  year      = 2021,
  volume    = 18,
  number    = 10,
  pages     = {1196--1203},
  doi       = {10.1038/s41592-021-01252-x},
  publisher = {Nature Publishing Group}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the Enformer paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.enformer

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): alphabet to use for tokenization.
    • If None, the standard DNA alphabet is used.
    • If a string, it must name a predefined alphabet: standard, iupac, streamline, or nucleobase.
    • If an alphabet or a list of characters, that specific alphabet is used.
  • nmers (int, default: 1): size of the k-mers to tokenize.
  • codon (bool, default: False): whether to tokenize into codons.
  • replace_U_with_T (bool, default: True): whether to replace U with T.
  • do_upper_case (bool, default: True): whether to convert the input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
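The id counts in these examples follow from how the text is split before lookup: nmers=3 uses an overlapping sliding window (seven 3-mers for a 9 bp input), while codon=True uses non-overlapping triplets (three codons). The sketch below restates the splitting rules from `_tokenize` with the default uppercase and U-to-T handling; `split_nmers` is a hypothetical name and does not perform the id lookup.

```python
def split_nmers(text: str, nmers: int = 1, codon: bool = False) -> list[str]:
    """Hypothetical restatement of DnaTokenizer's splitting rules (no vocabulary lookup)."""
    text = text.upper().replace("U", "T")  # mirrors do_upper_case=True and replace_U_with_T=True
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(
                f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
            )
        return [text[i : i + 3] for i in range(0, len(text), 3)]  # non-overlapping triplets
    if nmers > 1:
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]  # overlapping window
    return list(text)


print(split_nmers("tataaagta", nmers=3))     # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_nmers("tataaagta", codon=True))  # ['TAT', 'AAA', 'GTA']
```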
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)

EnformerConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of an EnformerModel. It is used to instantiate an Enformer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Enformer deepmind/enformer architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Enformer is the successor of Basenji. It replaces Basenji’s dilated convolution tower with a convolution stem followed by a Transformer trunk so it can model long-range genomic interactions. A long DNA window of sequence_length base pairs is downsampled by the convolution stem (2 ** num_downsamples, i.e. 128 bp per bin by default), processed by the Transformer trunk, cropped to target_length bins, and projected to genomic coverage tracks. The output has shape (batch_size, target_length, num_labels).
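The size parameters are not independent. The sketch below restates the checks visible in the `EnformerConfig.__init__` source, so a custom configuration can be sanity-checked without building a model; `check_enformer_config` is a hypothetical helper, and the real class remains the authority.

```python
def check_enformer_config(
    sequence_length: int,
    hidden_size: int,
    num_attention_heads: int,
    num_downsamples: int,
    target_length: int,
) -> int:
    """Hypothetical restatement of EnformerConfig's geometry checks; returns the number of bins."""
    if num_downsamples < 2:
        raise ValueError("num_downsamples must be >= 2")
    if hidden_size % 2 != 0:
        raise ValueError("hidden_size must be even")
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    # 3 relative-position basis families x 2 (signed) = 6 components per head
    if (hidden_size // num_attention_heads) % 6 != 0:
        raise ValueError("hidden_size // num_attention_heads must be divisible by 6")
    pool_factor = 2**num_downsamples
    if sequence_length % pool_factor != 0:
        raise ValueError("sequence_length must be divisible by 2 ** num_downsamples")
    num_bins = sequence_length // pool_factor
    if num_bins < target_length:
        raise ValueError("target_length must not exceed the number of bins")
    return num_bins


print(check_enformer_config(393_216, 1536, 8, 7, 896))  # 3072 (defaults)
print(check_enformer_config(256, 12, 2, 3, 16))         # 32 (the small example configuration)
```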

Parameters:

  • vocab_size (int, default: 5): Vocabulary size of the Enformer model. Defines the number of input feature channels derived from the MultiMolecule DNA token order (A, C, G, T, N).
  • sequence_length (int, default: 393216): The length, in base pairs, of the input DNA window (~393 kb).
  • hidden_size (int, default: 1536): Dimensionality of the Transformer trunk. The convolution stem’s first convolution produces hidden_size // 2 channels, and the conv tower grows back to hidden_size.
  • num_hidden_layers (int, default: 11): Number of Transformer blocks in the trunk.
  • num_attention_heads (int, default: 8): Number of attention heads in each Transformer block.
  • attention_head_size (int, default: 64): Dimensionality of the query/key projection per head.
  • num_downsamples (int, default: 7): Total number of 2x downsampling steps applied by the convolution stem. The binning factor is 2 ** num_downsamples (128 bp per bin by default).
  • dim_divisible_by (int, default: 128): The conv-tower channel sizes are rounded to a multiple of this value.
  • stem_kernel_size (int, default: 15): Kernel size of the first (stem) convolution.
  • conv_tower_kernel_size (int, default: 5): Kernel size of the main convolution in every conv-tower stage.
  • target_length (int, default: 896): Number of output bins kept after center-cropping the trunk output.
  • head_hidden_size (int | None, default: None): Dimensionality of the pointwise output head before the final track projection. None resolves to 2 * hidden_size.
  • hidden_act (str, default: 'quick_gelu'): The non-linear activation function used by the convolution blocks and the pointwise head. Enformer uses the sigmoid GELU approximation x * sigmoid(1.702 * x), which is quick_gelu in Transformers.
  • output_act (str, default: 'softplus'): Activation applied to the per-track predictions. Enformer applies softplus so the predicted coverage is non-negative.
  • hidden_dropout (float, default: 0.4): The dropout probability applied in the Transformer trunk.
  • attention_dropout (float, default: 0.05): The dropout probability applied to the attention matrix.
  • position_dropout (float, default: 0.01): The dropout probability applied to the relative positional features.
  • batch_norm_eps (float, default: 1e-05): The epsilon used by the batch normalization layers.
  • batch_norm_momentum (float, default: 0.1): The momentum used by the batch normalization layers.
  • species (str, default: 'human'): Output head to expose downstream. Enformer is trained with two species heads; the selected head determines num_labels. Use human (5313 tracks) or mouse (1643 tracks).
  • num_labels (int | None, default: None): Number of genomic coverage tracks predicted per bin. None resolves to the track count of the selected species head.
  • head (HeadConfig | None, default: None): Head configuration for the binned track prediction head. None resolves to a regression head.
  • output_contexts (bool, default: False): Whether to output the context vectors for each trunk block.
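The hidden_act default can be compared numerically with the exact GELU. This sketch implements both in plain Python (the overflow-safe sigmoid is an implementation detail of the sketch, not of the library); on [-5, 5] the two differ by at most about 0.02, so swapping one for the other would slightly change the outputs of pretrained weights.

```python
import math


def _sigmoid(z: float) -> float:
    """Logistic sigmoid written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def quick_gelu(x: float) -> float:
    """The sigmoid GELU approximation named in hidden_act: x * sigmoid(1.702 * x)."""
    return x * _sigmoid(1.702 * x)


def exact_gelu(x: float) -> float:
    """Exact GELU, x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x={x:+.1f}  quick_gelu={quick_gelu(x):+.4f}  exact_gelu={exact_gelu(x):+.4f}")
```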

Examples:

Python Console Session
>>> from multimolecule import EnformerConfig, EnformerModel
>>> # Initializing an Enformer multimolecule/enformer style configuration
>>> configuration = EnformerConfig()
>>> # Initializing a model (with random weights) from the multimolecule/enformer style configuration
>>> model = EnformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/enformer/configuration_enformer.py
Python
class EnformerConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of an
    [`EnformerModel`][multimolecule.models.EnformerModel]. It is used to instantiate an Enformer model according to
    the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the Enformer
    [deepmind/enformer](https://github.com/google-deepmind/deepmind-research/tree/master/enformer) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Enformer is the successor of Basenji. It replaces Basenji's dilated convolution tower with a
    convolution stem followed by a Transformer trunk so it can model long-range genomic
    interactions. A long DNA window of `sequence_length` base pairs is downsampled by the
    convolution stem (`2 ** num_downsamples`, i.e. 128 bp per bin by default), processed by the
    Transformer trunk, cropped to `target_length` bins, and projected to genomic coverage tracks.
    The output has shape `(batch_size, target_length, num_labels)`.

    Args:
        vocab_size:
            Vocabulary size of the Enformer model. Defines the number of input feature channels
            derived from the MultiMolecule DNA token order.
            Defaults to 5 (`A`, `C`, `G`, `T`, `N`).
        sequence_length:
            The length, in base pairs, of the input DNA window.
            Defaults to 393216 (~393 kb).
        hidden_size:
            Dimensionality of the Transformer trunk. The convolution stem's first conv produces
            `hidden_size // 2` channels and the conv tower grows back to `hidden_size`.
        num_hidden_layers:
            Number of Transformer blocks in the trunk.
        num_attention_heads:
            Number of attention heads in each Transformer block.
        attention_head_size:
            Dimensionality of the query/key projection per head.
        num_downsamples:
            Total number of 2x downsampling steps applied by the convolution stem. The binning
            factor is `2 ** num_downsamples` (128 bp per bin by default).
        dim_divisible_by:
            The conv-tower channel sizes are rounded to a multiple of this value.
        stem_kernel_size:
            Kernel size of the first (stem) convolution.
        conv_tower_kernel_size:
            Kernel size of the main convolution in every conv-tower stage.
        target_length:
            Number of output bins kept after center-cropping the trunk output.
        head_hidden_size:
            Dimensionality of the pointwise output head before the final track projection.
            Defaults to `2 * hidden_size`.
        hidden_act:
            The non-linear activation function used by the convolution blocks and the pointwise
            head. Enformer uses the sigmoid GELU approximation `x * sigmoid(1.702 * x)`, which is
            `quick_gelu` in Transformers.
        output_act:
            Activation applied to the per-track predictions. Enformer applies `softplus` so the
            predicted coverage is non-negative.
        hidden_dropout:
            The dropout probability applied in the Transformer trunk.
        attention_dropout:
            The dropout probability applied to the attention matrix.
        position_dropout:
            The dropout probability applied to the relative positional features.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers.
        species:
            Output head to expose downstream. Enformer is trained with two species heads; the
            selected head determines `num_labels`. Use `human` (5313 tracks) or `mouse`
            (1643 tracks).
        num_labels:
            Number of genomic coverage tracks predicted per bin. Defaults to the track count of
            the selected `species` head.
        head:
            Head configuration for the binned track prediction head.
        output_contexts:
            Whether to output the context vectors for each trunk block.

    Examples:
        >>> from multimolecule import EnformerConfig, EnformerModel
        >>> # Initializing an Enformer multimolecule/enformer style configuration
        >>> configuration = EnformerConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/enformer style configuration
        >>> model = EnformerModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "enformer"

    species_num_tracks = {"human": 5313, "mouse": 1643}

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 393216,
        hidden_size: int = 1536,
        num_hidden_layers: int = 11,
        num_attention_heads: int = 8,
        attention_head_size: int = 64,
        num_downsamples: int = 7,
        dim_divisible_by: int = 128,
        stem_kernel_size: int = 15,
        conv_tower_kernel_size: int = 5,
        target_length: int = 896,
        head_hidden_size: int | None = None,
        hidden_act: str = "quick_gelu",
        output_act: str = "softplus",
        hidden_dropout: float = 0.4,
        attention_dropout: float = 0.05,
        position_dropout: float = 0.01,
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        species: str = "human",
        num_labels: int | None = None,
        head: HeadConfig | None = None,
        output_contexts: bool = False,
        **kwargs,
    ):
        # Enformer is a feature-channel DNA model: it consumes a raw one-hot DNA window with no
        # special tokens, and its output is binned coverage tracks. There is no BOS/EOS/MASK token
        # on either the input or the binned positional axis, so the shared TokenPredictionHead must
        # not trim "special" bins.
        kwargs.setdefault("bos_token_id", None)
        kwargs.setdefault("eos_token_id", None)
        kwargs.setdefault("mask_token_id", None)
        kwargs.setdefault("null_token_id", None)
        if species not in self.species_num_tracks:
            raise ValueError(f"species must be one of {sorted(self.species_num_tracks)}, got {species!r}")
        if num_labels is None:
            num_labels = self.species_num_tracks[species]
        super().__init__(num_labels=num_labels, **kwargs)
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.num_downsamples = num_downsamples
        self.dim_divisible_by = dim_divisible_by
        self.stem_kernel_size = stem_kernel_size
        self.conv_tower_kernel_size = conv_tower_kernel_size
        self.target_length = target_length
        self.head_hidden_size = head_hidden_size if head_hidden_size is not None else 2 * hidden_size
        self.hidden_act = hidden_act
        self.output_act = output_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.position_dropout = position_dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.species = species
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_contexts = output_contexts

        if self.num_downsamples < 2:
            raise ValueError(f"num_downsamples must be >= 2, got {self.num_downsamples}")
        if self.hidden_size % 2 != 0:
            raise ValueError(f"hidden_size must be even, got {self.hidden_size}")
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by num_attention_heads "
                f"({self.num_attention_heads})"
            )
        # The relative positional encoding stacks 3 basis families x 2 (signed) = 6 components, so
        # the per-head feature size must be divisible by 6.
        if (self.hidden_size // self.num_attention_heads) % 6 != 0:
            raise ValueError(
                f"hidden_size // num_attention_heads "
                f"({self.hidden_size // self.num_attention_heads}) must be divisible by 6 so the "
                f"relative positional features are well defined."
            )
        if self.pool_factor <= 0:
            raise ValueError(f"pool_factor must be positive, got {self.pool_factor}")
        if self.sequence_length % self.pool_factor != 0:
            raise ValueError(
                f"sequence_length ({self.sequence_length}) must be divisible by the total pooling "
                f"factor ({self.pool_factor}) so the binned output is well defined."
            )
        if self.num_bins < self.target_length:
            raise ValueError(
                f"target_length ({self.target_length}) must not exceed the number of binned "
                f"positions ({self.num_bins})."
            )

    @property
    def pool_factor(self) -> int:
        r"""Total downsampling factor applied by the stem, i.e. base pairs per output bin."""
        return 2**self.num_downsamples

    @property
    def num_bins(self) -> int:
        r"""Number of binned positions produced by the trunk before center-cropping."""
        return self.sequence_length // self.pool_factor

pool_factor property

Python
pool_factor: int

Total downsampling factor applied by the stem, i.e. base pairs per output bin.

num_bins property

Python
num_bins: int

Number of binned positions produced by the trunk before center-cropping.

EnformerForTokenPrediction

Bases: EnformerPreTrainedModel

Enformer with a pointwise regression head over genomic coverage tracks.

The binned positional axis is treated as the “token” axis: logits have shape (batch_size, target_length, num_labels) where num_labels is the number of coverage tracks of the configured species head.
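“Pointwise” here means the same projection is applied to every bin independently, so the bin axis passes through unchanged while the feature axis becomes the track axis. A minimal sketch, assuming the head reduces to one linear layer followed by softplus; the weight layout and the helper name `pointwise_track_head` are hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pointwise_track_head(
    bins: list[list[float]],    # target_length x head_hidden_size
    weight: list[list[float]],  # num_labels x head_hidden_size (hypothetical layout)
    bias: list[float],          # num_labels
) -> list[list[float]]:
    """Apply one shared linear projection plus softplus to every bin; returns target_length x num_labels."""
    return [
        [softplus(sum(w * h for w, h in zip(row, hidden)) + b) for row, b in zip(weight, bias)]
        for hidden in bins
    ]


bins = [[1.0, 0.0], [0.0, 1.0]]              # two bins, two hidden features
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three hypothetical tracks
bias = [0.0, 0.0, -1.0]
print(pointwise_track_head(bins, weight, bias))  # 2 rows (bins) x 3 columns (tracks)
```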

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import EnformerConfig, EnformerForTokenPrediction
>>> config = EnformerConfig(
...     sequence_length=256, hidden_size=12, num_hidden_layers=1, num_attention_heads=2,
...     attention_head_size=4, num_downsamples=3, dim_divisible_by=2, target_length=16,
...     num_labels=4,
... )
>>> model = EnformerForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 256))
>>> output = model(input_ids, labels=torch.randn(1, 16, 4))
>>> output["logits"].shape
torch.Size([1, 16, 4])
Source code in multimolecule/models/enformer/modeling_enformer.py
Python
class EnformerForTokenPrediction(EnformerPreTrainedModel):
    """
    Enformer with a pointwise regression head over genomic coverage tracks.

    The binned positional axis is treated as the "token" axis: logits have shape
    `(batch_size, target_length, num_labels)` where `num_labels` is the number of coverage tracks
    of the configured `species` head.

    Examples:
        >>> import torch
        >>> from multimolecule import EnformerConfig, EnformerForTokenPrediction
        >>> config = EnformerConfig(
        ...     sequence_length=256, hidden_size=12, num_hidden_layers=1, num_attention_heads=2,
        ...     attention_head_size=4, num_downsamples=3, dim_divisible_by=2, target_length=16,
        ...     num_labels=4,
        ... )
        >>> model = EnformerForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 256))
        >>> output = model(input_ids, labels=torch.randn(1, 16, 4))
        >>> output["logits"].shape
        torch.Size([1, 16, 4])
    """

    def __init__(self, config: EnformerConfig):
        super().__init__(config)
        self.model = EnformerModel(config)
        token_head_config = HeadConfig(config.head) if config.head is not None else HeadConfig()
        if token_head_config.num_labels is None:
            token_head_config.num_labels = config.num_labels
        if token_head_config.hidden_size is None:
            token_head_config.hidden_size = config.head_hidden_size
        if token_head_config.problem_type is None:
            token_head_config.problem_type = "regression"
        if token_head_config.transform is None:
            token_head_config.transform = None
        if token_head_config.act is None:
            token_head_config.act = None
        self.token_head = TokenPredictionHead(config, token_head_config)
        self.head_config = self.token_head.config
        # Enformer applies softplus to the per-track predictions so coverage stays non-negative.
        self.output_act = _resolve_activation(config.output_act)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [
                str(id2label.get(index, f"{self.config.species}_track_{index}"))
                for index in range(self.config.num_labels)
            ]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"{self.config.species}_track_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_outputs = BaseModelOutput(last_hidden_state=outputs.last_hidden_state)

        # The binned axis has no special tokens; pass an all-ones mask so the shared head keeps
        # every bin.
        bin_mask = outputs.last_hidden_state.new_ones(outputs.last_hidden_state.shape[:2], dtype=torch.long)
        output = self.token_head(head_outputs, bin_mask, None, None)
        logits = output.logits
        if self.output_act is not None:
            logits = self.output_act(logits)

        loss = None
        if labels is not None:
            loss = self.token_head.criterion(logits, labels)

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

EnformerModel

Bases: EnformerPreTrainedModel

The bare Enformer backbone. Consumes a long DNA window and returns binned hidden states.

The positional axis of the output is binned: a window of config.sequence_length base pairs is downsampled by the convolution stem, processed by the Transformer trunk, and center-cropped so last_hidden_state has shape (batch_size, target_length, head_hidden_size).
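The center crop can be sketched on any list of bins. This version assumes the crop keeps the middle target_length entries and, when the surplus is odd, drops the extra bin from the right; `center_crop` is a hypothetical helper. In the small example below, 256 bp at 8 bp per bin gives 32 bins, of which the central 16 are kept.

```python
def center_crop(bins: list, target_length: int) -> list:
    """Hypothetical helper: keep the central target_length entries of a binned sequence."""
    if target_length > len(bins):
        raise ValueError(f"target_length ({target_length}) exceeds the number of bins ({len(bins)})")
    start = (len(bins) - target_length) // 2  # equal trim on both sides; an odd extra bin comes off the right
    return bins[start : start + target_length]


bin_indices = list(range(32))  # stand-in for 32 binned hidden states
print(center_crop(bin_indices, 16))  # [8, 9, ..., 23]
```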

Examples:

Python Console Session
>>> from multimolecule import EnformerConfig, EnformerModel
>>> config = EnformerConfig(
...     sequence_length=256, hidden_size=12, num_hidden_layers=1, num_attention_heads=2,
...     attention_head_size=4, num_downsamples=3, dim_divisible_by=2, target_length=16,
... )
>>> model = EnformerModel(config)
>>> import torch
>>> input_ids = torch.randint(config.vocab_size, (1, 256))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 16, 24])
Source code in multimolecule/models/enformer/modeling_enformer.py
Python
class EnformerModel(EnformerPreTrainedModel):
    """
    The bare Enformer backbone. Consumes a long DNA window and returns binned hidden states.

    The positional axis of the output is *binned*: a window of `config.sequence_length` base pairs
    is downsampled by the convolution stem, processed by the Transformer trunk, and center-cropped
    so `last_hidden_state` has shape `(batch_size, target_length, head_hidden_size)`.

    Examples:
        >>> from multimolecule import EnformerConfig, EnformerModel
        >>> config = EnformerConfig(
        ...     sequence_length=256, hidden_size=12, num_hidden_layers=1, num_attention_heads=2,
        ...     attention_head_size=4, num_downsamples=3, dim_divisible_by=2, target_length=16,
        ... )
        >>> model = EnformerModel(config)
        >>> import torch
        >>> input_ids = torch.randint(config.vocab_size, (1, 256))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 16, 24])
    """

    def __init__(self, config: EnformerConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embeddings = EnformerEmbedding(config)
        self.encoder = EnformerEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(embedding_output, **kwargs)

        return BaseModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
        )
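The input handling at the top of `forward` can be isolated for testing. The sketch below mirrors that logic in plain Python. `PaddedBatch` is a hypothetical stand-in for MultiMolecule's `NestedTensor` and assumes only the `.tensor` and `.mask` attributes used in the listing above. `resolve_inputs` is likewise a hypothetical name. The sketch makes the precedence rules explicit: exactly one of `input_ids` or `inputs_embeds` must be given, and a mask carried by a nested input is used only when no explicit `attention_mask` is passed.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PaddedBatch:
    """Hypothetical stand-in for NestedTensor: padded data plus a padding mask."""

    tensor: Any
    mask: Any


def resolve_inputs(
    input_ids: Any = None,
    attention_mask: Optional[Any] = None,
    inputs_embeds: Any = None,
) -> tuple[Any, Optional[Any], Any]:
    """Mirror the argument checks and nested-input unwrapping of EnformerModel.forward."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    # Unwrap a nested input; its mask is only a fallback for a missing attention_mask.
    if isinstance(input_ids, PaddedBatch):
        if attention_mask is None:
            attention_mask = input_ids.mask
        input_ids = input_ids.tensor
    if isinstance(inputs_embeds, PaddedBatch):
        if attention_mask is None:
            attention_mask = inputs_embeds.mask
        inputs_embeds = inputs_embeds.tensor
    return input_ids, attention_mask, inputs_embeds


batch = PaddedBatch(tensor=[[0, 1, 2]], mask=[[1, 1, 0]])
print(resolve_inputs(batch))  # ([[0, 1, 2]], [[1, 1, 0]], None)
print(resolve_inputs(batch, attention_mask=[[1, 1, 1]]))  # ([[0, 1, 2]], [[1, 1, 1]], None)
```

In the listing above, the resolved mask is passed to `self.embeddings` only. `self.encoder` receives just the embedding output and `**kwargs`.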

EnformerPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
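For the attention-pooling layers, `_init_weights` below sets `to_attn_logits` to a Dirac (identity) 1x1 kernel and multiplies it by 2. Assuming equal input and output channels, each channel's pooling logit at initialization is therefore twice that channel's value. The pure-Python sketch below illustrates the effect on a single channel. It assumes the pool splits the sequence into non-overlapping windows of `pool_size`, applies a softmax over each window's logits, and returns the softmax-weighted sum, which is how the upstream Enformer softmax pooling works. `softmax_pool` and `logit_scale` are hypothetical names, not MultiMolecule API.

```python
import math


def softmax_pool(values: list[float], pool_size: int = 2, logit_scale: float = 2.0) -> list[float]:
    """Softmax-weighted pooling of one channel over non-overlapping windows (illustrative sketch)."""
    if len(values) % pool_size != 0:
        raise ValueError("length must be a multiple of pool_size")
    pooled = []
    for start in range(0, len(values), pool_size):
        window = values[start : start + pool_size]
        # Identity kernel scaled by logit_scale: the logit for each position is logit_scale * value.
        logits = [logit_scale * v for v in window]
        peak = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(z - peak) for z in logits]
        total = sum(weights)
        pooled.append(sum(w / total * v for w, v in zip(weights, window)))
    return pooled


print(softmax_pool([1.0, 3.0, 2.0, 2.0], logit_scale=0.0))  # [2.0, 2.0]: average pooling
print(softmax_pool([1.0, 3.0, 2.0, 2.0], logit_scale=2.0))  # first window close to its max, 3.0
```

A `logit_scale` of 0 reduces to plain average pooling. The initialized value of 2 weights each window toward its largest entry, so the layer starts out as a soft form of max pooling rather than an average.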

Source code in multimolecule/models/enformer/modeling_enformer.py
Python
class EnformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EnformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["EnformerLayer", "EnformerConvLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
        elif isinstance(module, EnformerAttention):
            init.normal_(module.rel_content_bias)
            init.normal_(module.rel_pos_bias)
            init.zeros_(module.to_out.weight)
            init.zeros_(module.to_out.bias)
        elif isinstance(module, EnformerAttentionPool) and not getattr(
            module.to_attn_logits.weight, "_is_hf_initialized", False
        ):
            # `to_attn_logits` is a 1x1 Conv2d whose weight is a persistent parameter, so it is
            # restored from the checkpoint on `from_pretrained`. The guarded `init.dirac_` wrapper
            # respects `_is_hf_initialized` (no-op for loaded weights); we then scale by 2 to
            # reproduce the upstream softmax-pooling initialisation (identity logits with gain 2)
            # only when the weight was actually (re)initialised here. The method already runs
            # under `@torch.no_grad()`, so the in-place scaling needs no extra context manager.
            init.dirac_(module.to_attn_logits.weight)
            module.to_attn_logits.weight.mul_(2)