Basenji

Deep convolutional neural network for predicting genomic coverage tracks across chromosomes.

Disclaimer

This is an UNOFFICIAL implementation of "Sequential regulatory activity prediction across chromosomes with deep convolutional and recurrent neural networks" by David R. Kelley, Yakir A. Reshef, et al.

The OFFICIAL repository of Basenji is at calico/basenji.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released Basenji did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

Basenji is a deep convolutional neural network trained to predict genomic regulatory activity from long DNA sequences. It consumes a long DNA window (~131 kb), passes it through a convolution + pooling stem that downsamples the sequence, and then through a tower of dilated residual convolutional blocks that expand the receptive field. A pointwise output head predicts a vector of genomic coverage tracks for each output bin. Because the stem downsamples the input, the prediction is binned: the output has shape (batch_size, num_bins, num_tracks) where each bin summarizes 128 bp of sequence and num_tracks is the number of genomic coverage experiments.

Model Specification

| Input Length (bp) | Bin Size (bp) | Output Bins | Hidden Size | Dilated Blocks | Num Labels |
| --- | --- | --- | --- | --- | --- |
| 131,072 | 128 | 896 | 768 | 11 | 5,313 |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Genomic Coverage Prediction

You can use this model to predict binned genomic coverage tracks from a DNA sequence:

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, BasenjiConfig, BasenjiForTokenPrediction

>>> config = BasenjiConfig(
...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
...     stem_pool_size=2, head_hidden_size=8, crop_bins=2, num_labels=4,
...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
... )
>>> model = BasenjiForTokenPrediction(config)
>>> output = model(torch.randint(config.vocab_size, (1, 256)))
>>> output.logits.shape
torch.Size([1, 60, 4])

The binned positional axis is treated as the “token” axis: each output position corresponds to one genomic bin rather than a single nucleotide.

Interface

  • Input length: fixed 131,072 bp DNA window
  • Output binning: 128 bp per output bin; 896 output bins per window (after Cropping1D(64) on each side)
  • Output: (batch_size, num_bins, num_tracks); num_tracks defaults to 5,313 human coverage experiments
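To relate an output bin back to the genome, the numbers above can be combined into a small coordinate helper. This is a minimal sketch rather than part of the multimolecule API: `bin_interval` and its argument names are hypothetical, and it assumes bins are laid out left to right from the window start with the cropped bins removed symmetrically.

```python
def bin_interval(window_start, bin_index, sequence_length=131_072, bin_size=128, crop_bins=64):
    """Return the half-open interval [start, end) covered by one output bin.

    Hypothetical helper: assumes output bin 0 is the first bin left after
    cropping `crop_bins` bins from the left edge of the window.
    """
    num_bins = sequence_length // bin_size - 2 * crop_bins
    if not 0 <= bin_index < num_bins:
        raise IndexError(f"bin_index must be in [0, {num_bins}), got {bin_index}")
    start = window_start + (crop_bins + bin_index) * bin_size
    return start, start + bin_size


first = bin_interval(0, 0)    # first predicted bin of a window starting at 0
last = bin_interval(0, 895)   # last predicted bin of the same window
print(first, last)
```

Under these assumptions the predictions cover only the central 896 × 128 = 114,688 bp of each window; the outer 8,192 bp on each side provide context but receive no prediction.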

Training Details

Basenji was trained to predict genomic coverage tracks (DNase-seq, ATAC-seq, ChIP-seq and CAGE) from the human and mouse reference genomes.

Training Data

The model was trained on a large compendium of functional genomics experiments aligned to the human (hg38) and mouse (mm10) reference genomes. The genome was divided into overlapping windows; for each window the per-128-bp coverage of every experiment served as the regression target.
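The per-bin regression target can be pictured as per-base coverage pooled into 128 bp bins. The sketch below assumes the pooling is a plain sum; the aggregation actually used for a given experiment in the upstream data pipeline may differ, and `bin_coverage` is a hypothetical name used only for illustration.

```python
def bin_coverage(per_base, bin_size=128):
    """Pool per-base coverage into fixed-size bins by summing (assumed aggregation)."""
    if len(per_base) % bin_size != 0:
        raise ValueError("coverage length must be a multiple of bin_size")
    return [sum(per_base[i:i + bin_size]) for i in range(0, len(per_base), bin_size)]


# Three toy bins: uniform coverage, no coverage, half coverage.
toy_coverage = [1.0] * 128 + [0.0] * 128 + [0.5] * 128
targets = bin_coverage(toy_coverage)
print(targets)  # [128.0, 0.0, 64.0]
```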

Training Procedure

Pre-training

The model was trained to minimize a Poisson regression loss between predicted and observed coverage.
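For intuition, a Poisson regression loss treats each predicted value as the rate of a Poisson distribution over the observed count. The following is a minimal pure-Python sketch, not the upstream training code: it drops the `log(observed!)` term (constant with respect to the prediction), and the `eps` guard and mean reduction are assumptions. It also does not describe the criterion that `BasenjiForTokenPrediction` applies when `labels` are passed, which is configured through the shared head.

```python
import math


def poisson_nll(predicted, observed, eps=1e-8):
    """Mean Poisson negative log-likelihood without the constant log(observed!) term."""
    if len(predicted) != len(observed):
        raise ValueError("predicted and observed must have the same length")
    total = 0.0
    for rate, count in zip(predicted, observed):
        # `rate` must be positive, which a softplus output guarantees.
        total += rate - count * math.log(rate + eps)
    return total / len(predicted)


# The loss is lowest when the predicted rate equals the observed count.
loss_low = poisson_nll([1.0], [2.0])
loss_match = poisson_nll([2.0], [2.0])
loss_high = poisson_nll([3.0], [2.0])
print(loss_low, loss_match, loss_high)
```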

Citation

BibTeX
@article{kelley2018sequential,
  author    = {Kelley, David R. and Reshef, Yakir A. and Bileschi, Maxwell and Belanger, David and McLean, Cory Y. and Snoek, Jasper},
  title     = {Sequential regulatory activity prediction across chromosomes with deep convolutional and recurrent neural networks},
  journal   = {Genome Research},
  year      = 2018,
  volume    = 28,
  number    = 5,
  pages     = {739--750},
  doi       = {10.1101/gr.227819.117},
  publisher = {Cold Spring Harbor Laboratory}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the Basenji paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.basenji

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): Alphabet to use for tokenization.
    • If None, the standard DNA alphabet is used.
    • If a string, it must be the name of a predefined alphabet. The options are:
      • standard
      • iupac
      • streamline
      • nucleobase
    • If an Alphabet or a list of characters, that specific alphabet is used.
  • nmers (int, default: 1): Size of the k-mers to tokenize into.
  • codon (bool, default: False): Whether to tokenize into codons.
  • replace_U_with_T (bool, default: True): Whether to replace U with T.
  • do_upper_case (bool, default: True): Whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If is `None`, the standard DNA alphabet will be used.
            - If is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
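The difference between `nmers=3` and `codon=True` is the stride: k-mers slide one base at a time and overlap, whereas codons are consecutive, non-overlapping triplets. The standalone sketch below mirrors the slicing in `_tokenize` without depending on multimolecule; the function names are illustrative only.

```python
def overlapping_kmers(sequence, k):
    """Slide a window of size k one base at a time (stride 1)."""
    return [sequence[i:i + k] for i in range(len(sequence) - k + 1)]


def codons(sequence):
    """Split into consecutive, non-overlapping triplets (stride 3)."""
    if len(sequence) % 3 != 0:
        raise ValueError(f"length must be a multiple of 3, got {len(sequence)}")
    return [sequence[i:i + 3] for i in range(0, len(sequence), 3)]


sequence = "TATAAAGTA"
kmers = overlapping_kmers(sequence, 3)
triplets = codons(sequence)
print(kmers)     # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(triplets)  # ['TAT', 'AAA', 'GTA']
```

This matches the token counts in the examples above: seven k-mer tokens versus three codon tokens for the same nine-base input, each wrapped by the leading and trailing special tokens.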

BasenjiBlockConfig

Bases: FlatDict

Configuration for the dilated residual tower of the Basenji2 trunk.

Basenji2 stacks num_blocks dilated residual units. Each unit runs on a hidden_size-channel residual stream and internally bottlenecks to bottleneck_size channels for the dilated convolution before projecting back. The dilation factor starts at dilation and is multiplied by dilation_rate after every block (rounded to the nearest integer when round_dilation is set), which is how Basenji2 reaches the receptive field needed for ~131 kb input windows.
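The dilation schedule described above can be sketched in a few lines. This is an illustration under stated assumptions, not the library's implementation: `dilation_schedule` is a hypothetical helper, and ties such as 4.5 are resolved with Python's built-in `round`, which rounds halves to the nearest even integer. A different tie-breaking rule would shift a few entries.

```python
def dilation_schedule(num_blocks=11, dilation=1, dilation_rate=1.5, round_dilation=True):
    """Return the integer dilation used by each block of the tower (sketch)."""
    schedule = []
    current = float(dilation)
    for _ in range(num_blocks):
        schedule.append(int(round(current)))
        current *= dilation_rate
        if round_dilation:
            # Round the running value so rounding errors do not accumulate.
            current = float(round(current))
    return schedule


schedule = dilation_schedule()
print(schedule)  # [1, 2, 3, 4, 6, 9, 14, 21, 32, 48, 72]
```

With the default settings the last block already spans 72 bins between kernel taps, which is how eleven blocks with a kernel size of 3 can cover a large share of the 1,024 pre-crop bins.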

Parameters:

  • num_blocks (int, default: 11): Number of dilated residual blocks in the tower.
  • kernel_size (int, default: 3): Kernel size of the dilated (bottleneck) convolution.
  • bottleneck_size (int, default: 384): Channel count of the dilated convolution bottleneck.
  • dilation (int, default: 1): Dilation factor of the first block.
  • dilation_rate (float, default: 1.5): Multiplicative factor applied to the dilation after each block.
  • round_dilation (bool, default: True): Whether to round the running dilation to the nearest integer after each multiply (upstream Basenji2 uses round=true).
  • dropout (float, default: 0.3): Dropout probability applied to the projected (return) convolution of every block.
Source code in multimolecule/models/basenji/configuration_basenji.py
Python
class BasenjiBlockConfig(FlatDict):
    r"""
    Configuration for the dilated residual tower of the Basenji2 trunk.

    Basenji2 stacks `num_blocks` dilated residual units. Each unit runs on a `hidden_size`-channel
    residual stream and internally bottlenecks to `bottleneck_size` channels for the dilated
    convolution before projecting back. The dilation factor starts at `dilation` and is multiplied
    by `dilation_rate` after every block (rounded to the nearest integer when `round_dilation` is
    set), which is how Basenji2 reaches the receptive field needed for ~131 kb input windows.

    Args:
        num_blocks:
            Number of dilated residual blocks in the tower.
        kernel_size:
            Kernel size of the dilated (bottleneck) convolution.
        bottleneck_size:
            Channel count of the dilated convolution bottleneck.
        dilation:
            Dilation factor of the first block.
        dilation_rate:
            Multiplicative factor applied to the dilation after each block.
        round_dilation:
            Whether to round the running dilation to the nearest integer after each multiply
            (upstream Basenji2 uses `round=true`).
        dropout:
            Dropout probability applied to the projected (return) convolution of every block.
    """

    num_blocks: int = 11
    kernel_size: int = 3
    bottleneck_size: int = 384
    dilation: int = 1
    dilation_rate: float = 1.5
    round_dilation: bool = True
    dropout: float = 0.3

BasenjiConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a BasenjiModel. It is used to instantiate a Basenji model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration that faithfully reproduces the upstream Basenji2 human graph (calico/basenji, manuscripts/cross2020/params_human.json).

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Basenji2 predicts genomic coverage tracks at a binned resolution. A long DNA window of sequence_length base pairs is downsampled by the convolution + pooling stem and tower, then cropped by crop_bins bins on each side, so the output has shape (batch_size, num_bins, num_labels) where num_labels is the number of coverage tracks.

Parameters:

  • vocab_size (int, default: 5): Vocabulary size of the Basenji model. Defines the number of input feature channels derived from the MultiMolecule DNA token order. Defaults to 5 (A, C, G, T, N).
  • sequence_length (int, default: 131072): The length, in base pairs, of the input DNA window. Defaults to 131072 (~131 kb).
  • stem_channels (int, default: 288): Number of channels produced by the first (stem) convolution.
  • stem_kernel_size (int, default: 15): Kernel size of the first (stem) convolution.
  • stem_pool_size (int, default: 2): Pooling size applied after every convolution block in the stem and tower.
  • conv_tower_channels (list[int] | None, default: None): Explicit per-stage output channel schedule of the reducing convolution tower. Basenji2 grows the width as 339, 399, 470, 554, 652, 768; the tower length is len(conv_tower_channels) and each stage halves the resolution.
  • conv_tower_kernel_size (int, default: 5): Kernel size used by every convolution in the reducing tower.
  • blocks (BasenjiBlockConfig | None, default: None): Configuration of the dilated residual tower. A single BasenjiBlockConfig.
  • crop_bins (int, default: 64): Number of bins trimmed from each side of the binned axis after the dilated tower (upstream Cropping1D).
  • head_hidden_size (int, default: 1536): Channel count of the final pointwise convolution block feeding the track head.
  • hidden_act (str, default: 'gelu_new'): The non-linear activation used throughout the network. Basenji2 uses the tanh-approximation GELU (gelu_new).
  • head_act (str, default: 'softplus'): The activation applied to the final track projection. Basenji2 uses softplus.
  • hidden_dropout (float, default: 0.05): Dropout probability of the final pointwise convolution block.
  • batch_norm_eps (float, default: 0.001): The epsilon used by the batch normalization layers.
  • batch_norm_momentum (float, default: 0.1): The momentum used by the batch normalization layers (PyTorch convention; upstream Keras momentum 0.9 corresponds to PyTorch momentum 0.1).
  • num_labels (int, default: 5313): Number of genomic coverage tracks predicted per bin. Defaults to 5313 (the human track set released with Basenji2).
  • head (HeadConfig | None, default: None): The configuration of the binned track prediction head. Defaults to a regression head (problem_type="regression"), matching Basenji's genomic coverage prediction task.
  • output_contexts (bool, default: False): Whether to output the context vectors for each tower block.
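The momentum conversion noted for batch_norm_momentum follows from the two frameworks weighting the running statistic in opposite ways: Keras applies `momentum` to the old running value, while PyTorch applies it to the new batch statistic. The sketch below writes out both update rules as plain functions (the function names are illustrative) to show that Keras 0.9 and PyTorch 0.1 give the same update.

```python
def keras_running_update(running, batch_stat, momentum=0.9):
    """Keras convention: momentum weights the previous running value."""
    return momentum * running + (1.0 - momentum) * batch_stat


def torch_running_update(running, batch_stat, momentum=0.1):
    """PyTorch convention: momentum weights the new batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


keras_value = keras_running_update(2.0, 4.0)
torch_value = torch_running_update(2.0, 4.0)
print(keras_value, torch_value)  # both approximately 2.2
```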

Examples:

Python Console Session
>>> from multimolecule import BasenjiConfig, BasenjiModel
>>> # Initializing a Basenji multimolecule/basenji style configuration
>>> configuration = BasenjiConfig()
>>> # Initializing a model (with random weights) from the multimolecule/basenji style configuration
>>> model = BasenjiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/basenji/configuration_basenji.py
Python
class BasenjiConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BasenjiModel`][multimolecule.models.BasenjiModel]. It is used to instantiate a Basenji model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    configuration that faithfully reproduces the upstream Basenji2 human graph
    ([calico/basenji](https://github.com/calico/basenji), `manuscripts/cross2020/params_human.json`).

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Basenji2 predicts genomic coverage tracks at a *binned* resolution. A long DNA window of
    `sequence_length` base pairs is downsampled by the convolution + pooling stem and tower, then
    cropped by `crop_bins` bins on each side, so the output has shape
    `(batch_size, num_bins, num_labels)` where `num_labels` is the number of coverage tracks.

    Args:
        vocab_size:
            Vocabulary size of the Basenji model. Defines the number of input feature channels
            derived from the MultiMolecule DNA token order.
            Defaults to 5 (`A`, `C`, `G`, `T`, `N`).
        sequence_length:
            The length, in base pairs, of the input DNA window.
            Defaults to 131072 (~131 kb).
        stem_channels:
            Number of channels produced by the first (stem) convolution.
        stem_kernel_size:
            Kernel size of the first (stem) convolution.
        stem_pool_size:
            Pooling size applied after every convolution block in the stem and tower.
        conv_tower_channels:
            Explicit per-stage output channel schedule of the reducing convolution tower. Basenji2
            grows the width as ``339, 399, 470, 554, 652, 768``; the tower length is
            ``len(conv_tower_channels)`` and each stage halves the resolution.
        conv_tower_kernel_size:
            Kernel size used by every convolution in the reducing tower.
        blocks:
            Configuration of the dilated residual tower. A single [`BasenjiBlockConfig`].
        crop_bins:
            Number of bins trimmed from *each* side of the binned axis after the dilated tower
            (upstream `Cropping1D`).
        head_hidden_size:
            Channel count of the final pointwise convolution block feeding the track head.
        hidden_act:
            The non-linear activation used throughout the network. Basenji2 uses the
            tanh-approximation GELU (`gelu_new`).
        head_act:
            The activation applied to the final track projection. Basenji2 uses `softplus`.
        hidden_dropout:
            Dropout probability of the final pointwise convolution block.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers (PyTorch convention; upstream Keras
            momentum 0.9 corresponds to PyTorch momentum 0.1).
        num_labels:
            Number of genomic coverage tracks predicted per bin.
            Defaults to 5313 (the human track set released with Basenji2).
        head:
            The configuration of the binned track prediction head. Defaults to a regression head
            (`problem_type="regression"`), matching Basenji's genomic coverage prediction task.
        output_contexts:
            Whether to output the context vectors for each tower block.

    Examples:
        >>> from multimolecule import BasenjiConfig, BasenjiModel
        >>> # Initializing a Basenji multimolecule/basenji style configuration
        >>> configuration = BasenjiConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/basenji style configuration
        >>> model = BasenjiModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "basenji"

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 131072,
        stem_channels: int = 288,
        stem_kernel_size: int = 15,
        stem_pool_size: int = 2,
        conv_tower_channels: list[int] | None = None,
        conv_tower_kernel_size: int = 5,
        blocks: BasenjiBlockConfig | None = None,
        crop_bins: int = 64,
        head_hidden_size: int = 1536,
        hidden_act: str = "gelu_new",
        head_act: str = "softplus",
        hidden_dropout: float = 0.05,
        batch_norm_eps: float = 1e-3,
        batch_norm_momentum: float = 0.1,
        num_labels: int = 5313,
        head: HeadConfig | None = None,
        output_contexts: bool = False,
        **kwargs,
    ):
        # Basenji is a feature-channel DNA model: it consumes a raw one-hot DNA window with no
        # special tokens, and its output is binned coverage tracks. There is no BOS/EOS/MASK token
        # on either the input or the binned positional axis, so the shared TokenPredictionHead must
        # not trim "special" bins.
        kwargs.setdefault("bos_token_id", None)
        kwargs.setdefault("eos_token_id", None)
        kwargs.setdefault("mask_token_id", None)
        kwargs.setdefault("null_token_id", None)
        super().__init__(num_labels=num_labels, **kwargs)
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.stem_channels = stem_channels
        self.stem_kernel_size = stem_kernel_size
        self.stem_pool_size = stem_pool_size
        if conv_tower_channels is None:
            conv_tower_channels = [339, 399, 470, 554, 652, 768]
        self.conv_tower_channels = list(conv_tower_channels)
        self.conv_tower_kernel_size = conv_tower_kernel_size
        if blocks is None:
            blocks = BasenjiBlockConfig()
        self.blocks = blocks if isinstance(blocks, BasenjiBlockConfig) else BasenjiBlockConfig(**dict(blocks))
        self.crop_bins = crop_bins
        self.head_hidden_size = head_hidden_size
        self.hidden_act = hidden_act
        self.head_act = head_act
        self.hidden_dropout = hidden_dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_contexts = output_contexts

        if self.stem_pool_size < 1:
            raise ValueError(f"stem_pool_size must be >= 1, got {self.stem_pool_size}")
        if self.stem_channels < 1:
            raise ValueError(f"stem_channels must be >= 1, got {self.stem_channels}")
        if any(c < 1 for c in self.conv_tower_channels):
            raise ValueError(f"conv_tower_channels must be positive, got {self.conv_tower_channels}")
        if self.blocks.bottleneck_size < 1:
            raise ValueError(f"blocks.bottleneck_size must be >= 1, got {self.blocks.bottleneck_size}")
        if self.crop_bins < 0:
            raise ValueError(f"crop_bins must be >= 0, got {self.crop_bins}")
        if self.pool_factor <= 0:
            raise ValueError(f"pool_factor must be positive, got {self.pool_factor}")
        if self.sequence_length % self.pool_factor != 0:
            raise ValueError(
                f"sequence_length ({self.sequence_length}) must be divisible by the total pooling "
                f"factor ({self.pool_factor}) so the binned output is well defined."
            )
        if self.num_bins <= 0:
            raise ValueError(
                f"crop_bins ({self.crop_bins}) trims the entire binned axis "
                f"(pre-crop bins = {self.sequence_length // self.pool_factor}); reduce crop_bins."
            )

    @property
    def num_pool_layers(self) -> int:
        r"""Number of pooling stages: the stem block plus every reducing-tower stage."""
        return 1 + len(self.conv_tower_channels)

    @property
    def pool_factor(self) -> int:
        r"""Total downsampling factor applied by the stem and tower, i.e. base pairs per bin."""
        return self.stem_pool_size**self.num_pool_layers

    @property
    def hidden_size(self) -> int:
        r"""Channel count of the dilated residual stream."""
        return self.conv_tower_channels[-1] if self.conv_tower_channels else self.stem_channels

    @property
    def num_bins(self) -> int:
        r"""Number of output bins along the positional (token) axis, after cropping."""
        return self.sequence_length // self.pool_factor - 2 * self.crop_bins

num_pool_layers property

Python
num_pool_layers: int

Number of pooling stages: the stem block plus every reducing-tower stage.

pool_factor property

Python
pool_factor: int

Total downsampling factor applied by the stem and tower, i.e. base pairs per bin.

hidden_size property

Python
hidden_size: int

Channel count of the dilated residual stream.

num_bins property

Python
num_bins: int

Number of output bins along the positional (token) axis, after cropping.

BasenjiForTokenPrediction

Bases: BasenjiPreTrainedModel

Basenji2 with a pointwise regression head over genomic coverage tracks.

The binned positional axis is treated as the “token” axis: logits have shape (batch_size, num_bins, num_labels) where num_labels is the number of coverage tracks.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import BasenjiConfig, BasenjiForTokenPrediction
>>> config = BasenjiConfig(
...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
...     stem_pool_size=2, head_hidden_size=8, crop_bins=2, num_labels=4,
...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
... )
>>> model = BasenjiForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 256))
>>> output = model(input_ids, labels=torch.randn(1, 60, 4))
>>> output["logits"].shape
torch.Size([1, 60, 4])
Source code in multimolecule/models/basenji/modeling_basenji.py
Python
class BasenjiForTokenPrediction(BasenjiPreTrainedModel):
    """
    Basenji2 with a pointwise regression head over genomic coverage tracks.

    The binned positional axis is treated as the "token" axis: logits have shape
    `(batch_size, num_bins, num_labels)` where `num_labels` is the number of coverage tracks.

    Examples:
        >>> import torch
        >>> from multimolecule import BasenjiConfig, BasenjiForTokenPrediction
        >>> config = BasenjiConfig(
        ...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
        ...     stem_pool_size=2, head_hidden_size=8, crop_bins=2, num_labels=4,
        ...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
        ... )
        >>> model = BasenjiForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 256))
        >>> output = model(input_ids, labels=torch.randn(1, 60, 4))
        >>> output["logits"].shape
        torch.Size([1, 60, 4])
    """

    def __init__(self, config: BasenjiConfig):
        super().__init__(config)
        self.model = BasenjiModel(config)
        # The shared TokenPredictionHead is the upstream `Dense(head_hidden_size -> num_labels)`
        # final layer: an identity transform with a biased linear decoder. Upstream applies a
        # `softplus` activation on the Dense output; `softplus` is not part of `ACT2FN`, so it is
        # applied explicitly in `forward` below (the model's output transform) to keep parity
        # while reusing the shared head unchanged.
        token_head_config = HeadConfig(config.head) if config.head is not None else HeadConfig()
        if token_head_config.num_labels is None:
            token_head_config.num_labels = config.num_labels
        if token_head_config.hidden_size is None:
            token_head_config.hidden_size = config.head_hidden_size
        if token_head_config.problem_type is None:
            token_head_config.problem_type = "regression"
        if token_head_config.transform is None:
            token_head_config.transform = None
        if token_head_config.act is None:
            token_head_config.act = None
        self.token_head = TokenPredictionHead(config, token_head_config)
        self.head_config = self.token_head.config
        self.head_act = config.head_act
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"track_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"track_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_outputs = BaseModelOutput(last_hidden_state=outputs.last_hidden_state)
        # The binned axis has no special tokens; pass an all-ones mask so the shared head keeps
        # every bin. The head computes the unactivated upstream `Dense` projection.
        bin_mask = outputs.last_hidden_state.new_ones(outputs.last_hidden_state.shape[:2], dtype=torch.long)
        output = self.token_head(head_outputs, bin_mask, None, None)

        if self.head_act is None:
            logits = output.logits
        elif self.head_act == "softplus":
            logits = F.softplus(output.logits)
        else:
            logits = ACT2FN[self.head_act](output.logits)

        loss = None
        if labels is not None:
            loss = self.token_head.criterion(logits, labels)

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
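The softplus applied in `forward` maps any real-valued projection to a strictly positive number, which suits non-negative coverage values. The following is a minimal pure-Python sketch of the function itself, written in a numerically stable form; it illustrates the math only and is not how `torch.nn.functional.softplus` is implemented internally.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


at_zero = softplus(0.0)        # log(2)
far_negative = softplus(-50.0)  # tiny but still positive
far_positive = softplus(50.0)   # approximately the identity for large inputs
print(at_zero, far_negative, far_positive)
```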

BasenjiModel

Bases: BasenjiPreTrainedModel

The bare Basenji2 backbone. Consumes a long DNA window and returns binned hidden states.

The architecture faithfully reproduces the upstream Basenji2 trunk: a pre-activation convolution stem (GELU -> Conv -> BatchNorm -> MaxPool), a width-growing reducing tower, a dilated residual tower on a wide stream with a narrow bottleneck, a Cropping1D, and a final pointwise convolution block. The positional axis of the output is binned: a window of config.sequence_length base pairs is downsampled by the stem/tower and cropped, so last_hidden_state has shape (batch_size, num_bins, head_hidden_size).
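The stages listed above can be followed as a sequence of (length, channels) shapes. The sketch below derives them from configuration values alone, without building a model. It assumes every pooling stage divides the length by `stem_pool_size` and that the dilated tower and pointwise block preserve the length; `trace_shapes` and the stage labels are hypothetical names for illustration.

```python
def trace_shapes(sequence_length=131_072, stem_channels=288,
                 conv_tower_channels=(339, 399, 470, 554, 652, 768),
                 stem_pool_size=2, crop_bins=64, head_hidden_size=1536):
    """Return (stage, length, channels) after each stage of the backbone (sketch)."""
    stages = []
    length = sequence_length // stem_pool_size
    stages.append(("stem", length, stem_channels))
    for index, channels in enumerate(conv_tower_channels):
        length //= stem_pool_size  # each reducing stage halves the resolution
        stages.append((f"tower_{index}", length, channels))
    hidden_size = conv_tower_channels[-1] if conv_tower_channels else stem_channels
    stages.append(("dilated_tower", length, hidden_size))  # length preserved (assumed)
    length -= 2 * crop_bins
    stages.append(("crop", length, hidden_size))
    stages.append(("pointwise", length, head_hidden_size))
    return stages


default_shapes = trace_shapes()
toy_shapes = trace_shapes(256, 8, (8,), 2, 2, 8)
for stage in default_shapes:
    print(stage)
```

With the defaults the final stage is 896 bins by 1,536 channels; with the toy configuration used in the examples it is 60 bins by 8 channels.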

Examples:

Python Console Session
>>> from multimolecule import BasenjiConfig, BasenjiModel
>>> config = BasenjiConfig(
...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
...     stem_pool_size=2, head_hidden_size=8, crop_bins=2,
...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
... )
>>> model = BasenjiModel(config)
>>> import torch
>>> input_ids = torch.randint(config.vocab_size, (1, 256))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 60, 8])
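The number of bins in the example above can be checked by hand. The sketch below assumes that the stem and each reducing-tower stage pool by the given factor with no remainder, and that `crop_bins` bins are removed from each side. The helper `expected_num_bins` is hypothetical and not part of `multimolecule`. The per-stage pool factor of 2 for the tower and the crop of 64 bins for the full-size model are assumptions inferred from the shapes and the specification table on this page.

```python
def expected_num_bins(sequence_length: int, pool_sizes: list, crop_bins: int) -> int:
    """Estimate the binned output length (hypothetical helper, see assumptions above)."""
    length = sequence_length
    for pool in pool_sizes:
        # Each pooling stage divides the positional axis by its pool size.
        length //= pool
    # Cropping is assumed symmetric: crop_bins are dropped from both ends.
    return length - 2 * crop_bins


# Toy config from the example: stem pool 2, one tower stage (assumed pool 2), crop 2.
toy_bins = expected_num_bins(256, [2, 2], crop_bins=2)

# Full-size spec: 131,072 bp at 128 bp per bin (seven halvings), assumed crop of 64.
full_bins = expected_num_bins(131_072, [2] * 7, crop_bins=64)

print(toy_bins, full_bins)
```

Under these assumptions the toy configuration gives 256 / 4 - 4 = 60 bins, matching the printed shape, and the full-size configuration gives 1,024 - 128 = 896 bins, matching the specification table.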
Source code in multimolecule/models/basenji/modeling_basenji.py
Python
class BasenjiModel(BasenjiPreTrainedModel):
    """
    The bare Basenji2 backbone. Consumes a long DNA window and returns binned hidden states.

    The architecture faithfully reproduces the upstream Basenji2 trunk: a pre-activation
    convolution stem (`GELU -> Conv -> BatchNorm -> MaxPool`), a width-growing reducing tower, a
    dilated residual tower on a wide stream with a narrow bottleneck, a `Cropping1D`, and a final
    pointwise convolution block. The positional axis of the output is *binned*: a window of
    `config.sequence_length` base pairs is downsampled by the stem/tower and cropped, so
    `last_hidden_state` has shape `(batch_size, num_bins, head_hidden_size)`.

    Examples:
        >>> from multimolecule import BasenjiConfig, BasenjiModel
        >>> config = BasenjiConfig(
        ...     sequence_length=256, stem_channels=8, conv_tower_channels=[8],
        ...     stem_pool_size=2, head_hidden_size=8, crop_bins=2,
        ...     blocks={"num_blocks": 1, "kernel_size": 3, "bottleneck_size": 4},
        ... )
        >>> model = BasenjiModel(config)
        >>> import torch
        >>> input_ids = torch.randint(config.vocab_size, (1, 256))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 60, 8])
    """

    def __init__(self, config: BasenjiConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embeddings = BasenjiEmbedding(config)
        self.encoder = BasenjiEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(embedding_output, **kwargs)

        return BaseModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
        )
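The input handling at the top of `forward` follows a common pattern: require exactly one of `input_ids` or `inputs_embeds`, then unwrap a padded container and fall back to its mask only when no explicit `attention_mask` was passed. The sketch below reproduces that logic in plain Python. `PaddedBatch` is a hypothetical stand-in for `NestedTensor`, `resolve_inputs` is an illustrative helper, and nested lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class PaddedBatch:
    """Hypothetical stand-in for NestedTensor: padded data plus its mask."""

    tensor: Any
    mask: Any


def resolve_inputs(input_ids=None, attention_mask=None, inputs_embeds=None):
    """Illustrative helper mirroring the validation and unwrapping in forward."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    # Unwrap whichever input is a container; an explicit mask takes precedence.
    if isinstance(input_ids, PaddedBatch):
        if attention_mask is None:
            attention_mask = input_ids.mask
        input_ids = input_ids.tensor
    if isinstance(inputs_embeds, PaddedBatch):
        if attention_mask is None:
            attention_mask = inputs_embeds.mask
        inputs_embeds = inputs_embeds.tensor
    return input_ids, attention_mask, inputs_embeds


batch = PaddedBatch(tensor=[[1, 2, 0]], mask=[[1, 1, 0]])
ids, mask, embeds = resolve_inputs(input_ids=batch)
print(ids, mask, embeds)
```

Passing an explicit `attention_mask` alongside the container leaves the caller's mask untouched, which is the behavior of the `if attention_mask is None` guards in the source.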

BasenjiPreTrainedModel

Bases: PreTrainedModel

An abstract class that handles weight initialization and provides a simple interface for downloading and loading pretrained models.
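The `nn.Linear` branch of `_init_weights` in this class pairs `kaiming_uniform_` with `a=math.sqrt(5)` and a bias drawn from a uniform range of `1 / sqrt(fan_in)`. The sketch below works through the arithmetic, assuming the wrapper follows the Kaiming-uniform formula documented for `torch.nn.init` (leaky-ReLU gain, `fan_in` mode). The helper names are hypothetical, and `fan_in = 768` is only an illustrative value taken from the hidden size in the specification table.

```python
import math


def kaiming_uniform_bound(fan_in: int, a: float) -> float:
    """Uniform bound under the assumed formula: gain * sqrt(3 / fan_in)."""
    # Leaky-ReLU gain with negative slope a.
    gain = math.sqrt(2.0 / (1.0 + a * a))
    return gain * math.sqrt(3.0 / fan_in)


def bias_bound(fan_in: int) -> float:
    """Bias bound as computed in the Linear branch: 1 / sqrt(fan_in)."""
    return 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0.0


fan_in = 768  # illustrative value (assumption)
weight_bound = kaiming_uniform_bound(fan_in, a=math.sqrt(5))

print(f"weight bound: {weight_bound:.6f}")
print(f"bias bound:   {bias_bound(fan_in):.6f}")
```

With `a = sqrt(5)` the gain is `sqrt(1/3)`, so the weight bound simplifies to `1 / sqrt(fan_in)`, and weights and biases are drawn from the same range under these assumptions.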

Source code in multimolecule/models/basenji/modeling_basenji.py
Python
class BasenjiPreTrainedModel(PreTrainedModel):
    """
    An abstract class that handles weight initialization and provides a simple interface for downloading and
    loading pretrained models.
    """

    config_class = BasenjiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["BasenjiBlock"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)