
Borzoi

Sequence-to-coverage neural network for predicting RNA-seq and chromatin tracks across 524 kb DNA windows at 32 bp resolution.

Disclaimer

This is an UNOFFICIAL implementation of Predicting RNA-seq coverage from DNA sequence as a unifying model of gene regulation by Johannes Linder, Divyanshi Srivastava, Han Yuan, et al.

The OFFICIAL repository of Borzoi is at calico/borzoi.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.

The team releasing Borzoi did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

Borzoi is the successor of Enformer. It extends the Enformer recipe (convolution stem + Transformer trunk + binned multi-track output) to a 524,288 bp input window and 32 bp output bins, and adds a U-Net style upsampling tail that raises the resolution of the binned positional axis from 128 bp per position at the Transformer trunk to 32 bp per output bin. The 524,288 bp DNA window is downsampled by a convolution stem and a width-growing residual convolution tower, projected to 1,536 channels by a U-Net bottleneck, and processed by 8 Transformer blocks with Transformer-XL style relative positional encoding. It is then upsampled by two skip-connected U-Net stages with depthwise-separable convolutions, center-cropped to 6,144 bins (the central 196,608 bp), and projected to per-species coverage tracks with a softplus activation. The output is binned: it has shape (batch_size, target_length, num_tracks), where each bin summarizes 32 bp of sequence and num_tracks is the number of genomic coverage experiments for the selected species. Borzoi was trained jointly on RNA-seq, CAGE, ATAC-seq, DNase-seq, and ChIP-seq tracks. Please refer to the Training Details section for more information on the training process.
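The resolution arithmetic can be traced with a short, dependency-free sketch. It assumes, as the num_downsamples property of BorzoiConfig counts, that the stem pools once, every tower stage except the last pools once, the U-Net bottleneck and a final pre-Transformer pool each pool once, and each of the two U-Net stages upsamples by 2. The helper name trace_lengths is hypothetical and not part of multimolecule.

```python
# Hypothetical helper: trace the length of the positional axis through Borzoi,
# assuming every pooling step halves it and every U-Net stage doubles it.
def trace_lengths(sequence_length: int = 524_288, tower_stages: int = 5,
                  target_length: int = 6_144) -> dict[str, int]:
    lengths = {"input": sequence_length}
    length = sequence_length // 2          # stem convolution + pool
    lengths["stem"] = length
    for _ in range(tower_stages - 1):      # every tower stage except the last pools
        length //= 2
    lengths["tower"] = length
    length //= 2                           # U-Net bottleneck pool
    length //= 2                           # final pool before the Transformer trunk
    lengths["trunk"] = length
    length *= 4                            # two 2x U-Net upsampling stages
    lengths["unet"] = length
    lengths["cropped"] = target_length     # center crop to the returned bins
    return lengths


lengths = trace_lengths()
print(lengths)
print("bp per trunk position:", lengths["input"] // lengths["trunk"])
print("bp per output bin:", lengths["input"] // lengths["unet"])
print("bp covered by the cropped output:", lengths["cropped"] * lengths["input"] // lengths["unet"])
```

With the defaults this gives 4,096 trunk positions of 128 bp each and 16,384 output bins of 32 bp before the crop keeps the central 6,144.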

Variants

Borzoi provides two species output heads: human (7,611 tracks) and mouse (2,608 tracks).

Model Specification

| Input Length (bp) | Bin Size (bp) | Output Bins | Hidden Size | Layers | Heads | Num Labels | Num Parameters (M) | FLOPs (P) | MACs (P) |
|---|---|---|---|---|---|---|---|---|---|
| 524288 | 32 | 6144 | 1536 | 8 | 8 | 7611 | 185.90 | 13.57 | 6.76 |

The table reports figures for the human output head; the mouse head predicts 2,608 tracks instead. FLOPs and MACs are measured on the canonical 524,288 bp Borzoi input window.
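Because the output keeps every track for every bin, its size can be estimated directly from the table. The sketch below assumes float32 predictions (4 bytes per value) and ignores activations and parameters; output_bytes is a hypothetical helper.

```python
def output_bytes(target_length: int, num_tracks: int, batch_size: int = 1,
                 bytes_per_value: int = 4) -> int:
    """Bytes needed to hold (batch_size, target_length, num_tracks) float32 predictions."""
    return batch_size * target_length * num_tracks * bytes_per_value


human = output_bytes(6144, 7611)  # human head
mouse = output_bytes(6144, 2608)  # mouse head
print(f"human: {human:,} bytes ({human / 2**20:.1f} MiB per sequence)")
print(f"mouse: {mouse:,} bytes ({mouse / 2**20:.1f} MiB per sequence)")
```

For the human head this is roughly 178 MiB of output per input window, which is worth keeping in mind when choosing a batch size.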

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Genomic Coverage Prediction

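You can use this model to predict binned RNA-seq and chromatin coverage tracks from a DNA sequence. The example below builds a tiny, randomly initialized configuration so that it runs quickly; it only illustrates the input and output shapes: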

Python
>>> import torch
>>> from multimolecule import BorzoiConfig, BorzoiForTokenPrediction

>>> config = BorzoiConfig(
...     sequence_length=512, hidden_size=16, num_hidden_layers=1, num_attention_heads=2,
...     attention_head_size=4, attention_value_size=4, num_rel_pos_features=4,
...     stem_channels=8, conv_tower_channels=[12], head_hidden_size=8, target_length=16,
...     num_labels=4,
... )
>>> model = BorzoiForTokenPrediction(config)
>>> output = model(torch.randint(config.vocab_size, (1, 512)))
>>> output.logits.shape
torch.Size([1, 16, 4])

The binned positional axis is treated as the “token” axis: each output position corresponds to one genomic bin rather than a single nucleotide. The species configuration option selects the human (7,611 tracks) or mouse (2,608 tracks) output head.
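For the default window, the U-Net tail produces 16,384 bins of 32 bp and the center crop keeps 6,144 of them, so the returned bins cover the central 196,608 bp and (524,288 - 196,608) / 2 = 163,840 bp on each flank has no prediction. The hypothetical helper bin_interval below converts an output bin index into a half-open interval in the input window's coordinates; it assumes the crop is symmetric, which holds for the defaults because the number of trimmed bins is even.

```python
def bin_interval(bin_index: int, window_start: int = 0, sequence_length: int = 524_288,
                 bin_size: int = 32, target_length: int = 6_144) -> tuple[int, int]:
    """Return the half-open [start, end) span of sequence covered by one output bin.

    Assumes the U-Net output (sequence_length // bin_size bins) is center-cropped
    symmetrically to target_length bins.
    """
    if not 0 <= bin_index < target_length:
        raise IndexError(f"bin_index must be in [0, {target_length}), got {bin_index}")
    flank = (sequence_length - target_length * bin_size) // 2  # uncovered bp on each side
    start = window_start + flank + bin_index * bin_size
    return start, start + bin_size


print(bin_interval(0))      # (163840, 163872): first returned bin
print(bin_interval(6_143))  # (360416, 360448): last returned bin
```

Passing the genomic start of the window as window_start turns these offsets into genome coordinates.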

Interface

  • Input length: fixed 524,288 bp DNA window
  • Output binning: 32 bp per output bin; 6,144 output bins per window (after center-cropping the U-Net upsampling tail)
  • Species head: select human (7,611 tracks) or mouse (2,608 tracks) via the species config option
  • Output: (batch_size, target_length, num_tracks)
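Because the input length is fixed, a shorter or longer sequence has to be brought to exactly 524,288 bp before tokenization. One simple option, sketched below as a hypothetical helper, is to center the sequence and pad both ends with N (part of the A, C, G, T, N alphabet) or trim equally from both ends. Whether N padding is appropriate for a given analysis is an assumption worth checking; padding with real flanking genomic sequence is usually preferable when it is available.

```python
def center_window(sequence: str, length: int = 524_288, pad: str = "N") -> str:
    """Center `sequence` in a window of exactly `length` bp, padding or trimming both ends."""
    sequence = sequence.upper()
    if len(sequence) >= length:
        start = (len(sequence) - length) // 2   # trim equally from both ends
        return sequence[start:start + length]
    missing = length - len(sequence)
    left = missing // 2                          # any odd leftover goes to the right side
    return pad * left + sequence + pad * (missing - left)


window = center_window("ACGT" * 1000)
print(len(window), window[:5], window[-5:])
```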

Training Details

Borzoi was trained to predict bulk RNA-seq coverage together with chromatin tracks (DNase-seq, ATAC-seq, ChIP-seq) and CAGE from the human and mouse reference genomes.

Training Data

The model was trained on a large compendium of functional genomics experiments aligned to the human (hg38) and mouse (mm10) reference genomes. The genome was divided into 524 kb windows; for each window, the coverage of every experiment, aggregated into 32 bp bins over the central 6,144 bins (196,608 bp) that the model predicts, served as the regression target. RNA-seq coverage, the modality Borzoi adds over Enformer, dominates the training set; the remaining tracks cover the chromatin and CAGE modalities used by Enformer.
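Turning a per-base coverage track into 32 bp targets amounts to aggregating over consecutive, non-overlapping blocks. The sketch below sums each block, which is one common choice; the exact aggregation, scaling, and clipping used to build the upstream training targets are not described here and may differ. bin_coverage is a hypothetical helper.

```python
def bin_coverage(per_base: list[float], bin_size: int = 32) -> list[float]:
    """Sum per-base coverage over non-overlapping bins of `bin_size` bp."""
    if len(per_base) % bin_size != 0:
        raise ValueError("coverage length must be a multiple of bin_size")
    return [sum(per_base[i:i + bin_size]) for i in range(0, len(per_base), bin_size)]


coverage = [1.0] * 64 + [0.5] * 32  # 96 bp of toy coverage
print(bin_coverage(coverage))        # [32.0, 32.0, 16.0]
```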

Training Procedure

Pre-training

The model was trained to minimize a Poisson-multinomial regression loss between predicted and observed coverage, using a softplus output activation to keep the predicted coverage non-negative. Training used the Adam optimizer with a warmup schedule and global gradient-norm clipping; reverse-complement and small genomic-shift data augmentations were applied during training.
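The Poisson-multinomial loss splits the fit of each track into two parts: a Poisson term on the total coverage summed over bins, and a multinomial term on how that total is distributed across bins. The single-track sketch below follows that decomposition in plain Python; the total_weight, the epsilon, and the division by the number of bins are assumptions for illustration rather than the exact upstream settings, and poisson_multinomial_loss is a hypothetical name.

```python
import math


def poisson_multinomial_loss(y_true: list[float], y_pred: list[float],
                             total_weight: float = 1.0, epsilon: float = 1e-6) -> float:
    """Poisson-multinomial loss for one coverage track given per-bin values."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same number of bins")
    y_true = [y + epsilon for y in y_true]  # avoid log(0)
    y_pred = [y + epsilon for y in y_pred]
    total_true, total_pred = sum(y_true), sum(y_pred)
    # Poisson negative log-likelihood of the total count (constant terms dropped).
    poisson_term = total_pred - total_true * math.log(total_pred)
    # Multinomial negative log-likelihood of the per-bin shape.
    multinomial_term = -sum(t * math.log(p / total_pred) for t, p in zip(y_true, y_pred))
    return (multinomial_term + total_weight * poisson_term) / len(y_true)


observed = [1.0, 2.0, 3.0, 4.0]
print(poisson_multinomial_loss(observed, observed))                     # perfect prediction
print(poisson_multinomial_loss(observed, observed[::-1]))              # right total, wrong shape
print(poisson_multinomial_loss(observed, [2 * y for y in observed]))   # right shape, wrong total
```

Separating the two terms lets the model learn the shape of coverage along the window independently of the overall expression level, and total_weight controls how much the level matters.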

Citation

BibTeX
@article{linder2025predicting,
  author    = {Linder, Johannes and Srivastava, Divyanshi and Yuan, Han and Agarwal, Vikram and Kelley, David R.},
  title     = {Predicting RNA-seq coverage from DNA sequence as a unifying model of gene regulation},
  journal   = {Nature Genetics},
  year      = 2025,
  volume    = 57,
  number    = 4,
  pages     = {949--961},
  doi       = {10.1038/s41588-024-02053-6},
  publisher = {Nature Publishing Group}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the MultiMolecule GitHub issues for any questions or comments on the model card.

Please contact the authors of the Borzoi paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

BorzoiConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a BorzoiModel. It is used to instantiate a Borzoi model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration that reproduces the upstream Borzoi human architecture (calico/borzoi, examples/params_pred.json).

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Borzoi is the successor of Enformer. It extends the Enformer recipe (convolution stem + Transformer trunk + binned multi-track output) to a 524,288 bp input window and 32 bp output bins, and adds a U-Net style upsampling tail that raises the resolution of the binned positional axis from pool_factor base pairs per position at the Transformer trunk to output_bin_size base pairs per output bin. A long DNA window of sequence_length base pairs is downsampled by a convolution stem and a width-growing residual convolution tower, projected to hidden_size channels by a U-Net bottleneck, and processed by the Transformer trunk with Transformer-XL style relative positional encoding. It is then upsampled by two skip-connected U-Net stages with depthwise-separable convolutions, center-cropped to target_length bins, and projected to per-species coverage tracks with a softplus activation. The output has shape (batch_size, target_length, num_labels).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Vocabulary size of the Borzoi model. Defines the number of input feature channels derived from the MultiMolecule DNA token order. Defaults to 5 (A, C, G, T, N). | 5 |
| sequence_length | int | The length, in base pairs, of the input DNA window. Defaults to 524288 (2^19 bp, about 524 kb). | 524288 |
| hidden_size | int | Dimensionality of the Transformer trunk and the U-Net upsampling tail. | 1536 |
| num_hidden_layers | int | Number of Transformer blocks in the trunk. | 8 |
| num_attention_heads | int | Number of attention heads in each Transformer block. | 8 |
| attention_head_size | int | Dimensionality of the query/key projection per head. | 64 |
| attention_value_size | int | Dimensionality of the value projection per head. Borzoi uses a larger value dimension (192) than query/key dimension (64). | 192 |
| num_rel_pos_features | int | Number of relative positional features used by the Transformer-XL style attention. | 32 |
| stem_channels | int | Number of channels produced by the first (stem) convolution. | 512 |
| stem_kernel_size | int | Kernel size of the first (stem) convolution. | 15 |
| conv_tower_channels | list[int] \| None | Explicit per-stage output channel schedule of the reducing convolution tower; the tower length is len(conv_tower_channels). When None, Borzoi's schedule 608, 736, 896, 1056, 1280 is used. | None |
| conv_tower_kernel_size | int | Kernel size used by every convolution in the reducing tower. | 5 |
| unet_kernel_size | int | Kernel size of the depthwise-separable convolutions in the U-Net upsampling tail. | 3 |
| head_hidden_size | int | Channel count of the final pointwise convolution block feeding the per-species track head. | 1920 |
| hidden_act | str | The non-linear activation used throughout the convolution blocks. Borzoi uses the tanh-approximation GELU (gelu_new). | 'gelu_new' |
| output_act | str | Activation applied to the per-track predictions. Borzoi applies softplus so the predicted coverage is non-negative. | 'softplus' |
| hidden_dropout | float | Dropout probability of the final pointwise convolution block. | 0.1 |
| intermediate_dropout | float | Dropout probability applied inside the Transformer feed-forward sublayer. | 0.2 |
| attention_dropout | float | Dropout probability applied to the attention matrix. | 0.05 |
| position_dropout | float | Dropout probability applied to the relative positional features. | 0.01 |
| batch_norm_eps | float | Epsilon used by the batch normalization layers. | 0.001 |
| batch_norm_momentum | float | Momentum used by the batch normalization layers (PyTorch convention; upstream Keras momentum 0.9 corresponds to PyTorch momentum 0.1). | 0.1 |
| species | str | Output head to expose downstream. Borzoi is trained with two species heads; the selected head determines num_labels. Use 'human' (7611 tracks) or 'mouse' (2608 tracks); any other value raises a ValueError. | 'human' |
| target_length | int | Number of output bins kept after center-cropping the U-Net output. Defaults to 6144 (the bins_to_return setting of the upstream Borzoi inference path). Set to -1 to skip cropping; otherwise it must not exceed num_output_bins. | 6144 |
| num_labels | int \| None | Number of genomic coverage tracks predicted per bin. When None, defaults to the track count of the selected species head. | None |
| head | HeadConfig \| None | Head configuration for the binned track prediction head. When None, a regression HeadConfig is used; a supplied head without a problem_type is set to regression. | None |
| output_contexts | bool | Whether to output the context vectors for each trunk block. | False |
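The batch_norm_momentum convention differs between frameworks: Keras updates a running statistic as running = momentum * running + (1 - momentum) * batch, while PyTorch uses running = (1 - momentum) * running + momentum * batch. The conversion is therefore momentum_torch = 1 - momentum_keras, which the sketch below checks numerically for a single update; the helper names are hypothetical.

```python
import math


def keras_to_torch_momentum(keras_momentum: float) -> float:
    """Convert a Keras BatchNormalization momentum to the PyTorch convention."""
    return 1.0 - keras_momentum


def keras_update(running: float, batch: float, momentum: float) -> float:
    # Keras: momentum weights the old running value.
    return momentum * running + (1.0 - momentum) * batch


def torch_update(running: float, batch: float, momentum: float) -> float:
    # PyTorch: momentum weights the new batch statistic.
    return (1.0 - momentum) * running + momentum * batch


torch_momentum = keras_to_torch_momentum(0.9)
print(round(torch_momentum, 6))  # 0.1
# Both conventions give the same running statistic after one update.
assert math.isclose(keras_update(2.0, 4.0, 0.9), torch_update(2.0, 4.0, torch_momentum))
```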

Examples:

Python Console Session
>>> from multimolecule import BorzoiConfig, BorzoiModel
>>> # Initializing a Borzoi multimolecule/borzoi style configuration
>>> configuration = BorzoiConfig()
>>> # Initializing a model (with random weights) from the multimolecule/borzoi style configuration
>>> model = BorzoiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/borzoi/configuration_borzoi.py
Python
class BorzoiConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BorzoiModel`][multimolecule.models.BorzoiModel]. It is used to instantiate a Borzoi model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    configuration that reproduces the upstream Borzoi human architecture
    ([calico/borzoi](https://github.com/calico/borzoi), `examples/params_pred.json`).

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Borzoi is the successor of Enformer. It extends the Enformer recipe (convolution stem + Transformer trunk +
    binned multi-track output) to a 524,288 bp input window and 32 bp output bins, and adds a U-Net style
    upsampling tail so the binned positional axis matches a higher-resolution coverage prediction. A long DNA
    window of `sequence_length` base pairs is downsampled by a convolution stem and a width-growing residual
    convolution tower, projected to `hidden_size` channels by a U-Net bottleneck, processed by the Transformer
    trunk with Transformer-XL style relative positional encoding, then upsampled by two skip-connected U-Net
    stages with depthwise-separable convolutions, center-cropped to `target_length` bins, and projected to
    per-species coverage tracks with a softplus activation. The output has shape
    `(batch_size, target_length, num_labels)`.

    Args:
        vocab_size:
            Vocabulary size of the Borzoi model. Defines the number of input feature channels derived from the
            MultiMolecule DNA token order.
            Defaults to 5 (`A`, `C`, `G`, `T`, `N`).
        sequence_length:
            The length, in base pairs, of the input DNA window.
            Defaults to 524288 (2**19 bp, about 524 kb).
        hidden_size:
            Dimensionality of the Transformer trunk and the U-Net upsampling tail.
        num_hidden_layers:
            Number of Transformer blocks in the trunk.
        num_attention_heads:
            Number of attention heads in each Transformer block.
        attention_head_size:
            Dimensionality of the query/key projection per head.
        attention_value_size:
            Dimensionality of the value projection per head. Borzoi uses a larger value dim than key dim.
        num_rel_pos_features:
            Number of relative positional features used by the Transformer-XL style attention.
        stem_channels:
            Number of channels produced by the first (stem) convolution.
        stem_kernel_size:
            Kernel size of the first (stem) convolution.
        conv_tower_channels:
            Explicit per-stage output channel schedule of the reducing convolution tower. Borzoi grows the
            width as ``608, 736, 896, 1056, 1280``; the tower length is ``len(conv_tower_channels)``.
        conv_tower_kernel_size:
            Kernel size used by every convolution in the reducing tower.
        unet_kernel_size:
            Kernel size of the depthwise-separable convolutions in the U-Net upsampling tail.
        head_hidden_size:
            Channel count of the final pointwise convolution block feeding the per-species track head.
        hidden_act:
            The non-linear activation used throughout the convolution blocks. Borzoi uses the tanh-approximation
            GELU (`gelu_new`).
        output_act:
            Activation applied to the per-track predictions. Borzoi applies `softplus` so the predicted coverage
            is non-negative.
        hidden_dropout:
            Dropout probability of the final pointwise convolution block.
        intermediate_dropout:
            Dropout probability applied inside the Transformer feed-forward sublayer.
        attention_dropout:
            Dropout probability applied to the attention matrix.
        position_dropout:
            Dropout probability applied to the relative positional features.
        batch_norm_eps:
            Epsilon used by the batch normalization layers.
        batch_norm_momentum:
            Momentum used by the batch normalization layers (PyTorch convention; upstream Keras momentum 0.9
            corresponds to PyTorch momentum 0.1).
        species:
            Output head to expose downstream. Borzoi is trained with two species heads; the selected head
            determines `num_labels`. Use `human` (7611 tracks) or `mouse` (2608 tracks).
        target_length:
            Number of output bins kept after center-cropping the U-Net output. Defaults to 6144 (the
            `bins_to_return` setting of the upstream Borzoi inference path).
        num_labels:
            Number of genomic coverage tracks predicted per bin. Defaults to the track count of the selected
            `species` head.
        head:
            Head configuration for the binned track prediction head.
        output_contexts:
            Whether to output the context vectors for each trunk block.

    Examples:
        >>> from multimolecule import BorzoiConfig, BorzoiModel
        >>> # Initializing a Borzoi multimolecule/borzoi style configuration
        >>> configuration = BorzoiConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/borzoi style configuration
        >>> model = BorzoiModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "borzoi"

    species_num_tracks = {"human": 7611, "mouse": 2608}

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 524288,
        hidden_size: int = 1536,
        num_hidden_layers: int = 8,
        num_attention_heads: int = 8,
        attention_head_size: int = 64,
        attention_value_size: int = 192,
        num_rel_pos_features: int = 32,
        stem_channels: int = 512,
        stem_kernel_size: int = 15,
        conv_tower_channels: list[int] | None = None,
        conv_tower_kernel_size: int = 5,
        unet_kernel_size: int = 3,
        head_hidden_size: int = 1920,
        hidden_act: str = "gelu_new",
        output_act: str = "softplus",
        hidden_dropout: float = 0.1,
        intermediate_dropout: float = 0.2,
        attention_dropout: float = 0.05,
        position_dropout: float = 0.01,
        batch_norm_eps: float = 1e-3,
        batch_norm_momentum: float = 0.1,
        species: str = "human",
        target_length: int = 6144,
        num_labels: int | None = None,
        head: HeadConfig | None = None,
        output_contexts: bool = False,
        **kwargs,
    ):
        # Borzoi is a feature-channel DNA model: it consumes a raw one-hot DNA window with no special tokens,
        # and its output is binned coverage tracks. There is no BOS/EOS/MASK token on either the input or the
        # binned positional axis, so the shared TokenPredictionHead must not trim "special" bins.
        kwargs.setdefault("bos_token_id", None)
        kwargs.setdefault("eos_token_id", None)
        kwargs.setdefault("mask_token_id", None)
        kwargs.setdefault("null_token_id", None)
        if species not in self.species_num_tracks:
            raise ValueError(f"species must be one of {sorted(self.species_num_tracks)}, got {species!r}")
        if num_labels is None:
            num_labels = self.species_num_tracks[species]
        super().__init__(num_labels=num_labels, **kwargs)
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.attention_value_size = attention_value_size
        self.num_rel_pos_features = num_rel_pos_features
        self.stem_channels = stem_channels
        self.stem_kernel_size = stem_kernel_size
        if conv_tower_channels is None:
            conv_tower_channels = [608, 736, 896, 1056, 1280]
        self.conv_tower_channels = list(conv_tower_channels)
        self.conv_tower_kernel_size = conv_tower_kernel_size
        self.unet_kernel_size = unet_kernel_size
        self.head_hidden_size = head_hidden_size
        self.hidden_act = hidden_act
        self.output_act = output_act
        self.hidden_dropout = hidden_dropout
        self.intermediate_dropout = intermediate_dropout
        self.attention_dropout = attention_dropout
        self.position_dropout = position_dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.species = species
        self.target_length = target_length
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_contexts = output_contexts

        if self.stem_channels < 1:
            raise ValueError(f"stem_channels must be >= 1, got {self.stem_channels}")
        if any(c < 1 for c in self.conv_tower_channels):
            raise ValueError(f"conv_tower_channels must be positive, got {self.conv_tower_channels}")
        if self.hidden_size < 1:
            raise ValueError(f"hidden_size must be >= 1, got {self.hidden_size}")
        if self.num_attention_heads < 1:
            raise ValueError(f"num_attention_heads must be >= 1, got {self.num_attention_heads}")
        if self.target_length <= 0 and self.target_length != -1:
            raise ValueError(f"target_length must be positive (or -1 to skip cropping), got {self.target_length}")
        if self.pool_factor <= 0:
            raise ValueError(f"pool_factor must be positive, got {self.pool_factor}")
        if self.sequence_length % self.pool_factor != 0:
            raise ValueError(
                f"sequence_length ({self.sequence_length}) must be divisible by the total pooling factor "
                f"({self.pool_factor}) so the binned output is well defined."
            )
        if self.target_length != -1 and self.target_length > self.num_output_bins:
            raise ValueError(
                f"target_length ({self.target_length}) must not exceed the number of binned positions after "
                f"the U-Net upsampling tail ({self.num_output_bins})."
            )

    @property
    def num_downsamples(self) -> int:
        r"""Number of 2x downsampling stages: stem + tower + U-Net bottleneck + final pool."""
        # conv_dna pool (1) + one pool per reducing-tower stage except the last (len-1)
        # + unet1 pool (1) + final pool before transformer (1).
        return 1 + max(0, len(self.conv_tower_channels) - 1) + 2

    @property
    def num_upsamples(self) -> int:
        r"""Number of 2x upsampling stages in the U-Net tail."""
        return 2

    @property
    def pool_factor(self) -> int:
        r"""Total downsampling factor at the transformer trunk, i.e. base pairs per attention position."""
        return 2**self.num_downsamples

    @property
    def output_bin_size(self) -> int:
        r"""Base pairs per output bin, after the U-Net upsampling tail."""
        return 2 ** (self.num_downsamples - self.num_upsamples)

    @property
    def num_output_bins(self) -> int:
        r"""Number of binned positions produced by the U-Net upsampling tail before center-cropping."""
        return self.sequence_length // self.output_bin_size

num_downsamples property

Python
num_downsamples: int

Number of 2x downsampling stages: stem + tower + U-Net bottleneck + final pool.

num_upsamples property

Python
num_upsamples: int

Number of 2x upsampling stages in the U-Net tail.

pool_factor property

Python
pool_factor: int

Total downsampling factor at the transformer trunk, i.e. base pairs per attention position.

output_bin_size property

Python
output_bin_size: int

Base pairs per output bin, after the U-Net upsampling tail.

num_output_bins property

Python
num_output_bins: int

Number of binned positions produced by the U-Net upsampling tail before center-cropping.
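When shrinking the configuration for quick experiments, as the examples on this page do, the derived properties above must stay consistent with the length checks in BorzoiConfig.__init__. The hypothetical function below mirrors those formulas without importing multimolecule, so a candidate configuration can be sanity-checked up front; it is an illustrative re-implementation of the documented arithmetic, not part of the library.

```python
def borzoi_geometry(sequence_length: int, num_tower_stages: int, target_length: int) -> dict[str, int]:
    """Mirror BorzoiConfig's derived properties and length checks (illustrative re-implementation)."""
    # stem pool + one pool per tower stage except the last + U-Net bottleneck pool + final pool
    num_downsamples = 1 + max(0, num_tower_stages - 1) + 2
    num_upsamples = 2
    pool_factor = 2**num_downsamples
    output_bin_size = 2 ** (num_downsamples - num_upsamples)
    num_output_bins = sequence_length // output_bin_size
    if sequence_length % pool_factor != 0:
        raise ValueError(f"sequence_length must be divisible by pool_factor ({pool_factor})")
    if target_length != -1 and not 0 < target_length <= num_output_bins:
        raise ValueError(f"target_length must be in [1, {num_output_bins}] or -1")
    return {
        "pool_factor": pool_factor,
        "output_bin_size": output_bin_size,
        "num_output_bins": num_output_bins,
    }


print(borzoi_geometry(524_288, 5, 6_144))  # default human configuration
print(borzoi_geometry(512, 1, 16))         # tiny configuration used in the examples
```

The tiny example configuration yields 256 bins of 2 bp before the crop to 16 bins, which is why target_length=16 passes validation there.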

BorzoiForTokenPrediction

Bases: BorzoiPreTrainedModel

Borzoi with a pointwise regression head over genomic coverage tracks.

The binned positional axis is treated as the “token” axis: logits have shape (batch_size, target_length, num_labels) where num_labels is the number of coverage tracks of the configured species head.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import BorzoiConfig, BorzoiForTokenPrediction
>>> config = BorzoiConfig(
...     sequence_length=512, hidden_size=16, num_hidden_layers=1, num_attention_heads=2,
...     attention_head_size=4, attention_value_size=4, num_rel_pos_features=4,
...     stem_channels=8, conv_tower_channels=[12], head_hidden_size=8, target_length=16,
...     num_labels=4,
... )
>>> model = BorzoiForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 512))
>>> output = model(input_ids, labels=torch.randn(1, 16, 4))
>>> output["logits"].shape
torch.Size([1, 16, 4])
Source code in multimolecule/models/borzoi/modeling_borzoi.py
Python
class BorzoiForTokenPrediction(BorzoiPreTrainedModel):
    """
    Borzoi with a pointwise regression head over genomic coverage tracks.

    The binned positional axis is treated as the "token" axis: logits have shape
    `(batch_size, target_length, num_labels)` where `num_labels` is the number of coverage tracks
    of the configured `species` head.

    Examples:
        >>> import torch
        >>> from multimolecule import BorzoiConfig, BorzoiForTokenPrediction
        >>> config = BorzoiConfig(
        ...     sequence_length=512, hidden_size=16, num_hidden_layers=1, num_attention_heads=2,
        ...     attention_head_size=4, attention_value_size=4, num_rel_pos_features=4,
        ...     stem_channels=8, conv_tower_channels=[12], head_hidden_size=8, target_length=16,
        ...     num_labels=4,
        ... )
        >>> model = BorzoiForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 512))
        >>> output = model(input_ids, labels=torch.randn(1, 16, 4))
        >>> output["logits"].shape
        torch.Size([1, 16, 4])
    """

    def __init__(self, config: BorzoiConfig):
        super().__init__(config)
        self.model = BorzoiModel(config)
        token_head_config = HeadConfig(config.head) if config.head is not None else HeadConfig()
        if token_head_config.num_labels is None:
            token_head_config.num_labels = config.num_labels
        if token_head_config.hidden_size is None:
            token_head_config.hidden_size = config.head_hidden_size
        if token_head_config.problem_type is None:
            token_head_config.problem_type = "regression"
        # transform and act are left as configured; the softplus output activation is applied
        # separately through config.output_act in forward().
        self.token_head = TokenPredictionHead(config, token_head_config)
        self.head_config = self.token_head.config
        # Borzoi applies softplus to the per-track predictions so coverage stays non-negative.
        self.output_act = _resolve_activation(config.output_act)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [
                str(id2label.get(index, f"{self.config.species}_track_{index}"))
                for index in range(self.config.num_labels)
            ]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"{self.config.species}_track_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_outputs = BaseModelOutput(last_hidden_state=outputs.last_hidden_state)

        # The binned axis has no special tokens; pass an all-ones mask so the shared head keeps every bin.
        bin_mask = outputs.last_hidden_state.new_ones(outputs.last_hidden_state.shape[:2], dtype=torch.long)
        output = self.token_head(head_outputs, bin_mask, None, None)
        logits = output.logits
        if self.output_act is not None:
            logits = self.output_act(logits)

        loss = None
        if labels is not None:
            loss = self.token_head.criterion(logits, labels)

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
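The forward pass applies config.output_act (softplus by default) to the head logits, which maps any real value to a positive coverage prediction: softplus(x) = log(1 + exp(x)). A naive implementation overflows for large x (math.exp raises OverflowError around x = 710), so the sketch below uses the standard, mathematically identical rewrite max(x, 0) + log1p(exp(-|x|)). It is a plain-Python illustration, not the activation module that the library's _resolve_activation returns.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


for value in (-30.0, 0.0, 30.0, 1000.0):
    print(value, softplus(value))
```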

BorzoiModel

Bases: BorzoiPreTrainedModel

The bare Borzoi backbone. Consumes a long DNA window and returns binned hidden states.

The architecture follows the upstream Borzoi trunk: a pre-activation convolution stem with attention-pool downsampling, a width-growing residual convolution tower, a U-Net bottleneck pool, a Transformer trunk with Transformer-XL style relative positional encoding, two U-Net upsampling stages with depthwise-separable convolutions, and a center-crop. The positional axis of the output is binned: a window of config.sequence_length base pairs is downsampled and then re-upsampled, and last_hidden_state has shape (batch_size, target_length, head_hidden_size).
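A depthwise-separable convolution factors a k-wide convolution into a per-channel (depthwise) filter followed by a pointwise (kernel size 1) channel mix, which sharply reduces the weight count. The arithmetic below compares the two for C input and output channels, ignoring biases; evaluating it at hidden_size = 1536 and unet_kernel_size = 3 is an illustrative assumption about the width at which the U-Net convolutions run, and the helper names are hypothetical.

```python
def conv_weights(channels: int, kernel_size: int) -> int:
    """Weights of a standard Conv1d with `channels` inputs and outputs (no bias)."""
    return kernel_size * channels * channels


def separable_conv_weights(channels: int, kernel_size: int) -> int:
    """Weights of a depthwise Conv1d (groups=channels) followed by a pointwise Conv1d (no biases)."""
    depthwise = kernel_size * channels  # one k-wide filter per channel
    pointwise = channels * channels     # kernel size 1 channel mixing
    return depthwise + pointwise


standard = conv_weights(1536, 3)
separable = separable_conv_weights(1536, 3)
print(standard, separable, f"{standard / separable:.2f}x fewer weights")
```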

Examples:

Python Console Session
>>> from multimolecule import BorzoiConfig, BorzoiModel
>>> config = BorzoiConfig(
...     sequence_length=512, hidden_size=16, num_hidden_layers=1, num_attention_heads=2,
...     attention_head_size=4, attention_value_size=4, num_rel_pos_features=4,
...     stem_channels=8, conv_tower_channels=[12], head_hidden_size=8, target_length=16,
... )
>>> model = BorzoiModel(config)
>>> import torch
>>> input_ids = torch.randint(config.vocab_size, (1, 512))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 16, 8])
Source code in multimolecule/models/borzoi/modeling_borzoi.py
Python
class BorzoiModel(BorzoiPreTrainedModel):
    """
    The bare Borzoi backbone. Consumes a long DNA window and returns binned hidden states.

    The architecture follows the upstream Borzoi trunk: a pre-activation convolution stem with attention-pool
    downsampling, a width-growing residual convolution tower, a U-Net bottleneck pool, a Transformer trunk
    with Transformer-XL style relative positional encoding, two U-Net upsampling stages with depthwise-separable
    convolutions, and a center-crop. The positional axis of the output is *binned*: a window of
    `config.sequence_length` base pairs is downsampled and then re-upsampled, and `last_hidden_state` has shape
    `(batch_size, target_length, head_hidden_size)`.

    Examples:
        >>> from multimolecule import BorzoiConfig, BorzoiModel
        >>> config = BorzoiConfig(
        ...     sequence_length=512, hidden_size=16, num_hidden_layers=1, num_attention_heads=2,
        ...     attention_head_size=4, attention_value_size=4, num_rel_pos_features=4,
        ...     stem_channels=8, conv_tower_channels=[12], head_hidden_size=8, target_length=16,
        ... )
        >>> model = BorzoiModel(config)
        >>> import torch
        >>> input_ids = torch.randint(config.vocab_size, (1, 512))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 16, 8])
    """

    def __init__(self, config: BorzoiConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embeddings = BorzoiEmbedding(config)
        self.encoder = BorzoiEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(embedding_output, **kwargs)

        return BaseModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
        )

BorzoiPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/borzoi/modeling_borzoi.py
Python
class BorzoiPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BorzoiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["BorzoiLayer", "BorzoiConvLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
        elif isinstance(module, BorzoiAttention):
            init.normal_(module.rel_content_bias)
            init.normal_(module.rel_pos_bias)
            init.zeros_(module.to_out.weight)
            init.zeros_(module.to_out.bias)
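The nn.Linear branch of _init_weights reproduces PyTorch's default nn.Linear initialization. With a = sqrt(5), the Kaiming gain for leaky ReLU is sqrt(2 / (1 + a^2)) = sqrt(1/3), so the uniform weight bound gain * sqrt(3 / fan_in) simplifies to 1 / sqrt(fan_in), the same bound the code uses for the bias. The plain-Python check below recomputes that formula rather than calling torch; kaiming_uniform_bound is a hypothetical helper.

```python
import math


def kaiming_uniform_bound(fan_in: int, a: float = math.sqrt(5)) -> float:
    """Bound of kaiming_uniform_ with the leaky_relu gain and mode='fan_in'."""
    gain = math.sqrt(2.0 / (1.0 + a**2))
    std = gain / math.sqrt(fan_in)
    return math.sqrt(3.0) * std  # uniform(-b, b) has standard deviation b / sqrt(3)


for fan_in in (64, 1536):
    print(fan_in, kaiming_uniform_bound(fan_in), 1 / math.sqrt(fan_in))
```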