Basenji

Deep convolutional neural network for predicting genomic coverage tracks across chromosomes.

Disclaimer

This is an UNOFFICIAL implementation of "Sequential regulatory activity prediction across chromosomes with deep convolutional and recurrent neural networks" by David R. Kelley, Yakir A. Reshef, et al.

The OFFICIAL repository of Basenji is at calico/basenji.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released Basenji did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

Basenji is a deep convolutional neural network trained to predict genomic regulatory activity from long DNA sequences. It consumes a long DNA window (~131 kb), passes it through a convolution + pooling stem that downsamples the sequence, and then through a tower of dilated residual convolutional blocks that expand the receptive field. A pointwise output head predicts a vector of genomic coverage tracks for each output bin. Because the stem downsamples the input, the prediction is binned: the output has shape (batch_size, num_bins, num_tracks) where each bin summarizes 128 bp of sequence and num_tracks is the number of genomic coverage experiments.
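As a quick check on the shapes quoted above, the binning arithmetic can be reproduced in a few lines. This is a standalone sketch, not part of the multimolecule API: the helper name `binned_output_shape` is hypothetical, and its defaults assume seven pooling stages of size 2 and a crop of 64 bins per side, as listed in the API reference on this page.

```python
def binned_output_shape(sequence_length=131_072, pool_size=2, num_pool_layers=7,
                        crop_bins=64, num_tracks=5_313, batch_size=1):
    """Return (batch_size, num_bins, num_tracks) for a Basenji-style binned output.

    Hypothetical helper for illustration; it mirrors the arithmetic described in
    the text and does not call the library.
    """
    pool_factor = pool_size ** num_pool_layers  # base pairs summarized by one bin
    if sequence_length % pool_factor != 0:
        raise ValueError("sequence_length must be divisible by the pooling factor")
    num_bins = sequence_length // pool_factor - 2 * crop_bins
    if num_bins <= 0:
        raise ValueError("crop_bins trims the entire binned axis")
    return batch_size, num_bins, num_tracks


print(binned_output_shape())                 # (1, 896, 5313)
print(binned_output_shape(256, 2, 2, 2, 4))  # (1, 60, 4)
```

The second call uses the toy settings from the usage example on this page: 256 bp with two pooling stages gives 64 bins, and cropping 2 bins per side leaves 60.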

Model Specification

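| Input Length | Bin Size | Output Bins | Hidden Size | Dilated Blocks | Num Labels | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
| ------------ | -------- | ----------- | ----------- | -------------- | ---------- | ------------------ | --------- | -------- | -------------- |
| 131,072      | 128      | 896         | 768         | 11             | 5,313      | 30.09              | 234.85    | 117.19   | 131,072        |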

FLOPs and MACs are measured on the canonical 131,072 bp Basenji input window.

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Genomic Coverage Prediction

You can use this model to predict binned genomic coverage tracks from a DNA sequence:

Python
>>> import torch
>>> from multimolecule import BasenjiConfig, BasenjiForTokenPrediction

>>> # A deliberately tiny configuration so the example runs quickly.
>>> config = BasenjiConfig(
...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
...     stem_pool_size=2, head_hidden_size=8, crop_bins=2, num_labels=4,
...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
... )
>>> model = BasenjiForTokenPrediction(config)
>>> # Random token ids stand in for a tokenized DNA window of length 256.
>>> output = model(torch.randint(config.vocab_size, (1, 256)))
>>> # 256 bp / (2 ** 2) bp per bin = 64 bins; cropping 2 bins per side leaves 60.
>>> output.logits.shape
torch.Size([1, 60, 4])
>>> # postprocess applies the output activation and returns the track names.
>>> coverage, channels = model.postprocess(output)
>>> coverage.shape
torch.Size([1, 60, 4])

The binned positional axis is treated as the “token” axis: each output position corresponds to one genomic bin rather than a single nucleotide.
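To relate an output position back to the input window, the sketch below maps a bin index to the base-pair interval it summarizes. It is illustrative only: `bin_to_window_interval` is a hypothetical helper, and it assumes contiguous, non-overlapping bins with the same number of bins cropped from each side. That matches the interface described on this page but is not taken from the library's code.

```python
def bin_to_window_interval(bin_index, bin_size=128, crop_bins=64, num_bins=896):
    """Return the half-open interval [start, end) of input positions for one output bin.

    Assumption: output bin 0 is the first bin that survives cropping, so its
    offset inside the input window is crop_bins * bin_size.
    """
    if not 0 <= bin_index < num_bins:
        raise IndexError(f"bin_index must be in [0, {num_bins}), got {bin_index}")
    start = (crop_bins + bin_index) * bin_size
    return start, start + bin_size


print(bin_to_window_interval(0))    # (8192, 8320)
print(bin_to_window_interval(895))  # (122752, 122880)
```

With the toy configuration from the example above (4 bp per bin, 2 bins cropped per side, 60 bins), `bin_to_window_interval(0, 4, 2, 60)` gives `(8, 12)`. The interval is the region a bin is scored against; the sequence that influences the prediction is wider.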

Interface

  • Input length: fixed 131,072 bp DNA window
  • Output binning: 128 bp per output bin; 896 output bins per window (after Cropping1D(64) on each side)
  • Output: raw pre-softplus logits of shape (batch_size, num_bins, num_tracks); use postprocess for non-negative coverage tracks
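The relationship between the raw logits and non-negative coverage can be illustrated without the library. The following is a minimal sketch of the softplus function in plain Python; `softplus` here is a local helper written for this example, and it assumes the standard definition log(1 + exp(x)) with no scaling parameter.

```python
import math


def softplus(x):
    """Numerically stable softplus, log(1 + exp(x)).

    Rewritten as max(x, 0) + log1p(exp(-|x|)) so that large |x| does not overflow.
    """
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


logits = [-5.0, 0.0, 5.0]
coverage = [softplus(value) for value in logits]
# Every value is non-negative, and large positive logits pass through almost unchanged.
print(coverage)
```

Because softplus is monotonic, the ranking of bins within a track is the same before and after the transformation; only the scale changes.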

Training Details

Basenji was trained to predict genomic coverage tracks (DNase-seq, ATAC-seq, ChIP-seq and CAGE) from the human and mouse reference genomes.

Training Data

The model was trained on a large compendium of functional genomics experiments aligned to the human (hg38) and mouse (mm10) reference genomes. The genome was divided into overlapping windows; for each window the per-128-bp coverage of every experiment served as the regression target.
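The per-bin regression target can be pictured as an aggregation of base-resolution coverage over 128 bp blocks. The sketch below is an assumption-laden illustration: `bin_coverage` is a hypothetical helper, and summing within each bin is an assumed aggregation rule, since the text above only states that per-128-bp coverage is used.

```python
def bin_coverage(per_base_coverage, bin_size=128):
    """Aggregate base-resolution coverage into fixed-size bins.

    Assumption: values inside a bin are summed; other statistics (for example
    the mean) would only change the scale of the targets.
    """
    if len(per_base_coverage) % bin_size != 0:
        raise ValueError("coverage length must be a multiple of bin_size")
    return [
        sum(per_base_coverage[start:start + bin_size])
        for start in range(0, len(per_base_coverage), bin_size)
    ]


# Two bins: the first with one read depth per base, the second empty.
toy_coverage = [1.0] * 128 + [0.0] * 128
print(bin_coverage(toy_coverage))  # [128.0, 0.0]
```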

Training Procedure

Pre-training

The model was trained to minimize a Poisson regression loss between predicted and observed coverage.
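For readers unfamiliar with Poisson regression, the sketch below shows the usual form of the loss: the negative log-likelihood of the observed counts under a Poisson distribution whose rate is the prediction, with the constant log(observed!) term dropped. The function name `poisson_nll`, the `eps` guard, and the mean reduction are choices made for this example and are not claimed to match the original training code.

```python
import math


def poisson_nll(predicted, observed, eps=1e-7):
    """Mean Poisson negative log-likelihood, up to a constant that ignores predictions.

    predicted: non-negative rates (for example softplus outputs).
    observed:  non-negative coverage values.
    """
    if len(predicted) != len(observed):
        raise ValueError("predicted and observed must have the same length")
    total = 0.0
    for rate, count in zip(predicted, observed):
        total += rate - count * math.log(rate + eps)  # eps guards against log(0)
    return total / len(predicted)


observed = [2.0, 0.0, 5.0]
print(poisson_nll([2.0, 0.1, 5.0], observed))  # close to the observed values: lower loss
print(poisson_nll([0.5, 3.0, 1.0], observed))  # far from the observed values: higher loss
```

For a single bin the loss is minimized when the predicted rate equals the observed count, which is why the non-negative output activation matters.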

Citation

BibTeX
@article{kelley2018sequential,
  author    = {Kelley, David R. and Reshef, Yakir A. and Bileschi, Maxwell and Belanger, David and McLean, Cory Y. and Snoek, Jasper},
  title     = {Sequential regulatory activity prediction across chromosomes with deep convolutional and recurrent neural networks},
  journal   = {Genome Research},
  year      = 2018,
  volume    = 28,
  number    = 5,
  pages     = {739--750},
  doi       = {10.1101/gr.227819.117},
  publisher = {Cold Spring Harbor Laboratory}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If MultiMolecule supports your research, please cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the MultiMolecule GitHub issues for any questions or comments on the model card.

Please contact the authors of the Basenji paper for questions or comments on the paper or model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

BasenjiBlockConfig

Bases: FlatDict

Configuration for the dilated residual tower of the Basenji2 trunk.

Basenji2 stacks num_blocks dilated residual units. Each unit runs on a hidden_size-channel residual stream and internally bottlenecks to bottleneck_size channels for the dilated convolution before projecting back. The dilation factor starts at dilation and is multiplied by dilation_rate after every block (rounded to the nearest integer when round_dilation is set), which is how Basenji2 reaches the receptive field needed for ~131 kb input windows.
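The dilation schedule described above can be written out explicitly. The sketch below is a standalone illustration rather than the library's code: `dilation_schedule` is a hypothetical helper, and it assumes that the rounded value is carried forward into the next multiplication and that ties are rounded half to even, which is what Python's built-in `round` does.

```python
def dilation_schedule(num_blocks=11, dilation=1, dilation_rate=1.5, round_dilation=True):
    """List the integer dilation used by each block of the dilated residual tower.

    Assumptions: the running dilation is multiplied after every block, the
    rounded value feeds the next multiplication, and ties round half to even.
    """
    schedule = []
    running = float(dilation)
    for _ in range(num_blocks):
        schedule.append(int(round(running)))  # a convolution needs an integer dilation
        running *= dilation_rate
        if round_dilation:
            running = float(round(running))
    return schedule


print(dilation_schedule())
# [1, 2, 3, 4, 6, 9, 14, 21, 32, 48, 72]
```

A different tie-breaking rule would change the sequence from the fourth block onward (4.5 rounds to 4 here), so the values should be checked against the implementation before being relied upon.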

Parameters:

  • num_blocks (int, default: 11): Number of dilated residual blocks in the tower.
  • kernel_size (int, default: 3): Kernel size of the dilated (bottleneck) convolution.
  • bottleneck_size (int, default: 384): Channel count of the dilated convolution bottleneck.
  • dilation (int, default: 1): Dilation factor of the first block.
  • dilation_rate (float, default: 1.5): Multiplicative factor applied to the dilation after each block.
  • round_dilation (bool, default: True): Whether to round the running dilation to the nearest integer after each multiply (upstream Basenji2 uses round=true).
  • dropout (float, default: 0.3): Dropout probability applied to the projected (return) convolution of every block.

Source code in multimolecule/models/basenji/configuration_basenji.py
Python
class BasenjiBlockConfig(FlatDict):
    r"""
    Configuration for the dilated residual tower of the Basenji2 trunk.

    Basenji2 stacks `num_blocks` dilated residual units. Each unit runs on a `hidden_size`-channel
    residual stream and internally bottlenecks to `bottleneck_size` channels for the dilated
    convolution before projecting back. The dilation factor starts at `dilation` and is multiplied
    by `dilation_rate` after every block (rounded to the nearest integer when `round_dilation` is
    set), which is how Basenji2 reaches the receptive field needed for ~131 kb input windows.

    Args:
        num_blocks:
            Number of dilated residual blocks in the tower.
        kernel_size:
            Kernel size of the dilated (bottleneck) convolution.
        bottleneck_size:
            Channel count of the dilated convolution bottleneck.
        dilation:
            Dilation factor of the first block.
        dilation_rate:
            Multiplicative factor applied to the dilation after each block.
        round_dilation:
            Whether to round the running dilation to the nearest integer after each multiply
            (upstream Basenji2 uses `round=true`).
        dropout:
            Dropout probability applied to the projected (return) convolution of every block.
    """

    num_blocks: int = 11
    kernel_size: int = 3
    bottleneck_size: int = 384
    dilation: int = 1
    dilation_rate: float = 1.5
    round_dilation: bool = True
    dropout: float = 0.3

BasenjiConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a BasenjiModel. It is used to instantiate a Basenji model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration that faithfully reproduces the upstream Basenji2 human graph (calico/basenji, manuscripts/cross2020/params_human.json).

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Basenji2 predicts genomic coverage tracks at a binned resolution. A long DNA window of sequence_length base pairs is downsampled by the convolution + pooling stem and tower, then cropped by crop_bins bins on each side, so the output has shape (batch_size, num_bins, num_labels) where num_labels is the number of coverage tracks.

Parameters:

  • vocab_size (int, default: 5): Vocabulary size of the Basenji model. Defines the number of input feature channels derived from the MultiMolecule DNA token order. Defaults to 5 (A, C, G, T, N).
  • sequence_length (int, default: 131072): The length, in base pairs, of the input DNA window. Defaults to 131072 (~131 kb).
  • stem_channels (int, default: 288): Number of channels produced by the first (stem) convolution.
  • stem_kernel_size (int, default: 15): Kernel size of the first (stem) convolution.
  • stem_pool_size (int, default: 2): Pooling size applied after every convolution block in the stem and tower.
  • conv_tower_channels (list[int] | None, default: None): Explicit per-stage output channel schedule of the reducing convolution tower. Basenji2 grows the width as 339, 399, 470, 554, 652, 768; the tower length is len(conv_tower_channels) and each stage halves the resolution.
  • conv_tower_kernel_size (int, default: 5): Kernel size used by every convolution in the reducing tower.
  • blocks (BasenjiBlockConfig | None, default: None): Configuration of the dilated residual tower. A single BasenjiBlockConfig.
  • crop_bins (int, default: 64): Number of bins trimmed from each side of the binned axis after the dilated tower (upstream Cropping1D).
  • head_hidden_size (int, default: 1536): Channel count of the final pointwise convolution block feeding the track head.
  • hidden_act (str, default: 'gelu_new'): The non-linear activation used throughout the network. Basenji2 uses the tanh-approximation GELU (gelu_new).
  • output_act (str, default: 'softplus'): The activation applied to the final track projection. Basenji2 uses softplus.
  • hidden_dropout (float, default: 0.05): Dropout probability of the final pointwise convolution block.
  • batch_norm_eps (float, default: 0.001): The epsilon used by the batch normalization layers.
  • batch_norm_momentum (float, default: 0.1): The momentum used by the batch normalization layers (PyTorch convention; upstream Keras momentum 0.9 corresponds to PyTorch momentum 0.1).
  • num_labels (int, default: 5313): Number of genomic coverage tracks predicted per bin. Defaults to 5313 (the human track set released with Basenji2).
  • head (HeadConfig | None, default: None): The configuration of the binned track prediction head. Defaults to a regression head (problem_type="regression"), matching Basenji’s genomic coverage prediction task.
  • output_contexts (bool, default: False): Whether to output the context vectors for each tower block.
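The note on batch_norm_momentum is easy to misread, because the two frameworks weight the running statistic in opposite ways. The sketch below spells out both update rules in plain Python; the function names are hypothetical and exist only for this illustration.

```python
def keras_to_pytorch_momentum(keras_momentum):
    """Convert a Keras-style momentum to the PyTorch convention."""
    return 1.0 - keras_momentum


def keras_update(running, batch_stat, momentum):
    # Keras convention: momentum is the weight kept on the old running value.
    return momentum * running + (1.0 - momentum) * batch_stat


def pytorch_update(running, batch_stat, momentum):
    # PyTorch convention: momentum is the weight given to the new batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat


pt_momentum = keras_to_pytorch_momentum(0.9)
# Both conventions produce the same running mean once the momentum is converted.
print(keras_update(1.0, 3.0, 0.9), pytorch_update(1.0, 3.0, pt_momentum))
```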

Examples:

Python Console Session
>>> from multimolecule import BasenjiConfig, BasenjiModel
>>> # Initializing a Basenji multimolecule/basenji style configuration
>>> configuration = BasenjiConfig()
>>> # Initializing a model (with random weights) from the multimolecule/basenji style configuration
>>> model = BasenjiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/basenji/configuration_basenji.py
Python
class BasenjiConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BasenjiModel`][multimolecule.models.BasenjiModel]. It is used to instantiate a Basenji model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    configuration that faithfully reproduces the upstream Basenji2 human graph
    ([calico/basenji](https://github.com/calico/basenji), `manuscripts/cross2020/params_human.json`).

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Basenji2 predicts genomic coverage tracks at a *binned* resolution. A long DNA window of
    `sequence_length` base pairs is downsampled by the convolution + pooling stem and tower, then
    cropped by `crop_bins` bins on each side, so the output has shape
    `(batch_size, num_bins, num_labels)` where `num_labels` is the number of coverage tracks.

    Args:
        vocab_size:
            Vocabulary size of the Basenji model. Defines the number of input feature channels
            derived from the MultiMolecule DNA token order.
            Defaults to 5 (`A`, `C`, `G`, `T`, `N`).
        sequence_length:
            The length, in base pairs, of the input DNA window.
            Defaults to 131072 (~131 kb).
        stem_channels:
            Number of channels produced by the first (stem) convolution.
        stem_kernel_size:
            Kernel size of the first (stem) convolution.
        stem_pool_size:
            Pooling size applied after every convolution block in the stem and tower.
        conv_tower_channels:
            Explicit per-stage output channel schedule of the reducing convolution tower. Basenji2
            grows the width as ``339, 399, 470, 554, 652, 768``; the tower length is
            ``len(conv_tower_channels)`` and each stage halves the resolution.
        conv_tower_kernel_size:
            Kernel size used by every convolution in the reducing tower.
        blocks:
            Configuration of the dilated residual tower. A single [`BasenjiBlockConfig`].
        crop_bins:
            Number of bins trimmed from *each* side of the binned axis after the dilated tower
            (upstream `Cropping1D`).
        head_hidden_size:
            Channel count of the final pointwise convolution block feeding the track head.
        hidden_act:
            The non-linear activation used throughout the network. Basenji2 uses the
            tanh-approximation GELU (`gelu_new`).
        output_act:
            The activation applied to the final track projection. Basenji2 uses `softplus`.
        hidden_dropout:
            Dropout probability of the final pointwise convolution block.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers (PyTorch convention; upstream Keras
            momentum 0.9 corresponds to PyTorch momentum 0.1).
        num_labels:
            Number of genomic coverage tracks predicted per bin.
            Defaults to 5313 (the human track set released with Basenji2).
        head:
            The configuration of the binned track prediction head. Defaults to a regression head
            (`problem_type="regression"`), matching Basenji's genomic coverage prediction task.
        output_contexts:
            Whether to output the context vectors for each tower block.

    Examples:
        >>> from multimolecule import BasenjiConfig, BasenjiModel
        >>> # Initializing a Basenji multimolecule/basenji style configuration
        >>> configuration = BasenjiConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/basenji style configuration
        >>> model = BasenjiModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "basenji"

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 131072,
        stem_channels: int = 288,
        stem_kernel_size: int = 15,
        stem_pool_size: int = 2,
        conv_tower_channels: list[int] | None = None,
        conv_tower_kernel_size: int = 5,
        blocks: BasenjiBlockConfig | None = None,
        crop_bins: int = 64,
        head_hidden_size: int = 1536,
        hidden_act: str = "gelu_new",
        output_act: str = "softplus",
        hidden_dropout: float = 0.05,
        batch_norm_eps: float = 1e-3,
        batch_norm_momentum: float = 0.1,
        num_labels: int = 5313,
        head: HeadConfig | None = None,
        output_contexts: bool = False,
        **kwargs,
    ):
        # Basenji is a feature-channel DNA model: it consumes a raw one-hot DNA window with no
        # special tokens, and its output is binned coverage tracks. There is no BOS/EOS/MASK token
        # on either the input or the binned positional axis, so the shared TokenPredictionHead must
        # not trim "special" bins.
        kwargs.setdefault("bos_token_id", None)
        kwargs.setdefault("eos_token_id", None)
        kwargs.setdefault("mask_token_id", None)
        kwargs.setdefault("null_token_id", None)
        super().__init__(num_labels=num_labels, **kwargs)
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.stem_channels = stem_channels
        self.stem_kernel_size = stem_kernel_size
        self.stem_pool_size = stem_pool_size
        if conv_tower_channels is None:
            conv_tower_channels = [339, 399, 470, 554, 652, 768]
        self.conv_tower_channels = list(conv_tower_channels)
        self.conv_tower_kernel_size = conv_tower_kernel_size
        if blocks is None:
            blocks = BasenjiBlockConfig()
        self.blocks = blocks if isinstance(blocks, BasenjiBlockConfig) else BasenjiBlockConfig(**dict(blocks))
        self.crop_bins = crop_bins
        self.head_hidden_size = head_hidden_size
        self.hidden_act = hidden_act
        self.output_act = output_act
        self.hidden_dropout = hidden_dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_contexts = output_contexts

        if self.stem_pool_size < 1:
            raise ValueError(f"stem_pool_size must be >= 1, got {self.stem_pool_size}")
        if self.stem_channels < 1:
            raise ValueError(f"stem_channels must be >= 1, got {self.stem_channels}")
        if any(c < 1 for c in self.conv_tower_channels):
            raise ValueError(f"conv_tower_channels must be positive, got {self.conv_tower_channels}")
        if self.blocks.bottleneck_size < 1:
            raise ValueError(f"blocks.bottleneck_size must be >= 1, got {self.blocks.bottleneck_size}")
        if self.crop_bins < 0:
            raise ValueError(f"crop_bins must be >= 0, got {self.crop_bins}")
        if self.pool_factor <= 0:
            raise ValueError(f"pool_factor must be positive, got {self.pool_factor}")
        if self.sequence_length % self.pool_factor != 0:
            raise ValueError(
                f"sequence_length ({self.sequence_length}) must be divisible by the total pooling "
                f"factor ({self.pool_factor}) so the binned output is well defined."
            )
        if self.num_bins <= 0:
            raise ValueError(
                f"crop_bins ({self.crop_bins}) trims the entire binned axis "
                f"(pre-crop bins = {self.sequence_length // self.pool_factor}); reduce crop_bins."
            )

    @property
    def num_pool_layers(self) -> int:
        r"""Number of pooling stages: the stem block plus every reducing-tower stage."""
        return 1 + len(self.conv_tower_channels)

    @property
    def pool_factor(self) -> int:
        r"""Total downsampling factor applied by the stem and tower, i.e. base pairs per bin."""
        return self.stem_pool_size**self.num_pool_layers

    @property
    def hidden_size(self) -> int:
        r"""Channel count of the dilated residual stream."""
        return self.conv_tower_channels[-1] if self.conv_tower_channels else self.stem_channels

    @property
    def num_bins(self) -> int:
        r"""Number of output bins along the positional (token) axis, after cropping."""
        return self.sequence_length // self.pool_factor - 2 * self.crop_bins

num_pool_layers property

Python
num_pool_layers: int

Number of pooling stages: the stem block plus every reducing-tower stage.

pool_factor property

Python
pool_factor: int

Total downsampling factor applied by the stem and tower, i.e. base pairs per bin.

hidden_size property

Python
hidden_size: int

Channel count of the dilated residual stream.

num_bins property

Python
num_bins: int

Number of output bins along the positional (token) axis, after cropping.

BasenjiForTokenPrediction

Bases: BasenjiPreTrainedModel

Basenji2 with a pointwise regression head over genomic coverage tracks.

The binned positional axis is treated as the “token” axis: logits have shape (batch_size, num_bins, num_labels) where num_labels is the number of coverage tracks.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import BasenjiConfig, BasenjiForTokenPrediction
>>> config = BasenjiConfig(
...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
...     stem_pool_size=2, head_hidden_size=8, crop_bins=2, num_labels=4,
...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
... )
>>> model = BasenjiForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 256))
>>> output = model(input_ids, labels=torch.randn(1, 60, 4))
>>> output["logits"].shape
torch.Size([1, 60, 4])
Source code in multimolecule/models/basenji/modeling_basenji.py
Python
class BasenjiForTokenPrediction(BasenjiPreTrainedModel):
    """
    Basenji2 with a pointwise regression head over genomic coverage tracks.

    The binned positional axis is treated as the "token" axis: logits have shape
    `(batch_size, num_bins, num_labels)` where `num_labels` is the number of coverage tracks.

    Examples:
        >>> import torch
        >>> from multimolecule import BasenjiConfig, BasenjiForTokenPrediction
        >>> config = BasenjiConfig(
        ...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
        ...     stem_pool_size=2, head_hidden_size=8, crop_bins=2, num_labels=4,
        ...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
        ... )
        >>> model = BasenjiForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 256))
        >>> output = model(input_ids, labels=torch.randn(1, 60, 4))
        >>> output["logits"].shape
        torch.Size([1, 60, 4])
    """

    def __init__(self, config: BasenjiConfig):
        super().__init__(config)
        self.model = BasenjiModel(config)
        token_head_config = HeadConfig(config.head) if config.head is not None else HeadConfig()
        if token_head_config.num_labels is None:
            token_head_config.num_labels = config.num_labels
        if token_head_config.hidden_size is None:
            token_head_config.hidden_size = config.head_hidden_size
        if token_head_config.problem_type is None:
            token_head_config.problem_type = "regression"
        self.token_head = TokenPredictionHead(config, token_head_config)
        self.head_config = self.token_head.config
        # Upstream applies `softplus` (absent from `ACT2FN`) to the linear track output; resolve it once and apply it
        # to the logits in `postprocess` and before the loss in `forward` to obtain non-negative coverage, keeping the
        # shared head unchanged.
        self.output_act: Callable[[Tensor], Tensor] | None = None
        if config.output_act == "softplus":
            self.output_act = F.softplus
        elif config.output_act is not None:
            self.output_act = ACT2FN[config.output_act]
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"track_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"track_{index}" for index in range(self.config.num_labels)]

    def postprocess(self, outputs: TokenPredictorOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
        r"""Return the non-negative per-track coverage prediction with track channel names."""
        logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
        coverage = self.output_act(logits) if self.output_act is not None else logits
        return coverage, self.output_channels

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_outputs = BaseModelOutput(last_hidden_state=outputs.last_hidden_state)
        # The binned axis has no special tokens; pass an all-ones mask so the shared head keeps every bin.
        bin_mask = outputs.last_hidden_state.new_ones(outputs.last_hidden_state.shape[:2], dtype=torch.long)
        output = self.token_head(head_outputs, bin_mask, None, None)
        logits = output.logits  # pre-activation; non-negative coverage = output_act(logits), exposed via postprocess

        loss = None
        if labels is not None:
            activated = self.output_act(logits) if self.output_act is not None else logits
            loss = self.token_head.criterion(activated, labels)

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

postprocess

Python
postprocess(
    outputs: TokenPredictorOutput | ModelOutput | Tensor,
) -> tuple[Tensor, list[str]]

Return the non-negative per-track coverage prediction with track channel names.

Source code in multimolecule/models/basenji/modeling_basenji.py
Python
def postprocess(self, outputs: TokenPredictorOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
    r"""Return the non-negative per-track coverage prediction with track channel names."""
    logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
    coverage = self.output_act(logits) if self.output_act is not None else logits
    return coverage, self.output_channels
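The track names returned alongside the coverage follow a simple fallback rule, visible in the `output_channels` property in the class source: custom `id2label` entries are used if any exist, and otherwise the names default to `track_<index>`. The sketch below restates that rule as a standalone function so it can be tried without a model; `resolve_track_names` is a hypothetical name, and the example labels are made up.

```python
def resolve_track_names(num_labels, id2label=None):
    """Pick human-readable names for the coverage tracks.

    Placeholder names of the form LABEL_<index> are treated as "no names given".
    """
    if id2label is not None:
        labels = [str(id2label.get(index, f"track_{index}")) for index in range(num_labels)]
        # Keep the mapping only if at least one entry is not a generic placeholder.
        if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
            return labels
    return [f"track_{index}" for index in range(num_labels)]


print(resolve_track_names(2))                                  # ['track_0', 'track_1']
print(resolve_track_names(2, {0: "LABEL_0", 1: "LABEL_1"}))    # ['track_0', 'track_1']
print(resolve_track_names(2, {0: "DNase:A", 1: "CAGE:B"}))     # ['DNase:A', 'CAGE:B']
```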

BasenjiModel

Bases: BasenjiPreTrainedModel

The bare Basenji2 backbone. Consumes a long DNA window and returns binned hidden states.

The architecture faithfully reproduces the upstream Basenji2 trunk: a pre-activation convolution stem (GELU -> Conv -> BatchNorm -> MaxPool), a width-growing reducing tower, a dilated residual tower on a wide stream with a narrow bottleneck, a Cropping1D, and a final pointwise convolution block. The positional axis of the output is binned: a window of config.sequence_length base pairs is downsampled by the stem/tower and cropped, so last_hidden_state has shape (batch_size, num_bins, head_hidden_size).
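The stages listed above can be traced as a sequence of (length, channels) pairs. The sketch below is a rough illustration, not the model code: `trace_shapes` is a hypothetical helper, and it assumes that every convolution preserves length, so only the pooling steps and the crop change the positional axis.

```python
def trace_shapes(sequence_length=131_072, stem_channels=288,
                 conv_tower_channels=(339, 399, 470, 554, 652, 768),
                 pool_size=2, crop_bins=64, head_hidden_size=1536):
    """Return a list of (stage, length, channels) tuples for a Basenji-style trunk."""
    stages = []
    length = sequence_length // pool_size        # the stem pools once
    channels = stem_channels
    stages.append(("stem", length, channels))
    for index, channels in enumerate(conv_tower_channels):
        length //= pool_size                     # each reducing stage pools again
        stages.append((f"tower_{index}", length, channels))
    stages.append(("dilated_tower", length, channels))  # residual stream keeps its shape
    length -= 2 * crop_bins                      # crop the same number of bins per side
    stages.append(("crop", length, channels))
    stages.append(("head_conv", length, head_hidden_size))
    return stages


for stage, length, channels in trace_shapes():
    print(f"{stage:>14}: length={length:>6}, channels={channels}")
```

With the default arguments the final stage has length 896 and 1536 channels, which agrees with the `(batch_size, num_bins, head_hidden_size)` shape stated above.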

Examples:

Python Console Session
>>> from multimolecule import BasenjiConfig, BasenjiModel
>>> config = BasenjiConfig(
...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
...     stem_pool_size=2, head_hidden_size=8, crop_bins=2,
...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
... )
>>> model = BasenjiModel(config)
>>> import torch
>>> input_ids = torch.randint(config.vocab_size, (1, 256))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 60, 8])
Source code in multimolecule/models/basenji/modeling_basenji.py
Python
class BasenjiModel(BasenjiPreTrainedModel):
    """
    The bare Basenji2 backbone. Consumes a long DNA window and returns binned hidden states.

    The architecture faithfully reproduces the upstream Basenji2 trunk: a pre-activation
    convolution stem (`GELU -> Conv -> BatchNorm -> MaxPool`), a width-growing reducing tower, a
    dilated residual tower on a wide stream with a narrow bottleneck, a `Cropping1D`, and a final
    pointwise convolution block. The positional axis of the output is *binned*: a window of
    `config.sequence_length` base pairs is downsampled by the stem/tower and cropped, so
    `last_hidden_state` has shape `(batch_size, num_bins, head_hidden_size)`.

    Examples:
        >>> from multimolecule import BasenjiConfig, BasenjiModel
        >>> config = BasenjiConfig(
        ...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
        ...     stem_pool_size=2, head_hidden_size=8, crop_bins=2,
        ...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
        ... )
        >>> model = BasenjiModel(config)
        >>> import torch
        >>> input_ids = torch.randint(config.vocab_size, (1, 256))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 60, 8])
    """

    def __init__(self, config: BasenjiConfig):
        super().__init__(config)
        self.embeddings = BasenjiEmbedding(config)
        self.encoder = BasenjiEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(embedding_output, **kwargs)

        return BaseModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
        )

BasenjiPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/basenji/modeling_basenji.py
Python
class BasenjiPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BasenjiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["BasenjiBlock"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
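To give a sense of scale for the initialization above, the sketch below computes the standard deviation implied by Kaiming-normal initialization in `fan_out` mode with a ReLU gain, and the uniform bias bound used for linear layers. The helper names are hypothetical, and the formulas assume PyTorch's documented definitions: gain sqrt(2) for ReLU, and fan_out equal to out_channels times kernel_size for a 1D convolution.

```python
import math


def kaiming_normal_std_fan_out(out_channels, kernel_size):
    """Standard deviation of Kaiming-normal weights for a Conv1d in fan_out mode."""
    fan_out = out_channels * kernel_size
    gain = math.sqrt(2.0)  # gain associated with ReLU
    return gain / math.sqrt(fan_out)


def linear_bias_bound(in_features):
    """Half-width of the uniform range used for a linear layer's bias."""
    return 1.0 / math.sqrt(in_features) if in_features > 0 else 0.0


# Stem convolution with the default configuration: 288 output channels, kernel size 15.
print(kaiming_normal_std_fan_out(288, 15))
# Bias bound for a linear layer fed by the 1536-channel head block.
print(linear_bias_bound(1536))
```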