
ChromBPNet

Bias-factorized, base-resolution convolutional neural network for predicting chromatin accessibility (ATAC-seq / DNase-seq) from DNA sequence.

Disclaimer

This is an UNOFFICIAL implementation of ChromBPNet: bias factorized, base-resolution deep learning models of chromatin accessibility reveal cis-regulatory sequence syntax, transcription factor footprints and regulatory variants by Anusri Pampari et al.

The OFFICIAL repository of ChromBPNet is at kundajelab/chrombpnet.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released ChromBPNet did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

ChromBPNet is a convolutional neural network (CNN) trained to predict base-resolution chromatin accessibility (ATAC-seq or DNase-seq) from primary DNA sequence with explicit enzyme-bias correction. It builds on the BPNet architecture and internally composes a bias sub-model with an accessibility sub-model. The composed output is factorized into profile and count branches, and the usable base-resolution prediction is reconstructed by ChromBPNetForProfilePrediction.postprocess. Please refer to the Training Details section for more information on the training process.

Model Specification

| Input Length (bp) | Profile Length (bp) | Num Layers | Hidden Size | Bias Hidden Size | Num Parameters (M) |
| --- | --- | --- | --- | --- | --- |
| 2114 | 1000 | 9 + 5 | 512 | 128 | 5.5 |

The accessibility sub-model has 1 stem convolution + 8 dilated residual blocks (512 filters); the bias sub-model has 1 stem convolution + 4 dilated residual blocks (128 filters).
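The 2114 bp to 1000 bp shrinkage can be worked out with simple arithmetic. This assumes that every convolution is unpadded ("valid") and that dilated block i uses dilation 2^(i+1), which is what the receptive-field check in ChromBPNetConfig.__init__ encodes. The function names below are illustrative and are not part of multimolecule.

```python
# Illustrative arithmetic only (not multimolecule code). Assumes unpadded ("valid")
# convolutions and dilation 2 ** (i + 1) for dilated block i, as encoded in the
# receptive-field check of ChromBPNetConfig.__init__.


def valid_length(sequence_length: int, stem_kernel_size: int, dilated_kernel_size: int, num_dilated_layers: int) -> int:
    """Positions left after the stem and the dilated residual blocks."""
    trimmed = stem_kernel_size - 1  # the stem convolution removes kernel_size - 1 positions
    for i in range(num_dilated_layers):
        # a dilated convolution removes (kernel_size - 1) * dilation positions
        trimmed += (dilated_kernel_size - 1) * 2 ** (i + 1)
    return sequence_length - trimmed


def profile_positions(valid: int, profile_kernel_size: int) -> int:
    """Positions left after the wide profile convolution."""
    return valid - (profile_kernel_size - 1)


accessibility = valid_length(2114, stem_kernel_size=21, dilated_kernel_size=3, num_dilated_layers=8)
bias = valid_length(2114, stem_kernel_size=21, dilated_kernel_size=3, num_dilated_layers=4)
print(accessibility, profile_positions(accessibility, profile_kernel_size=75))  # 1074 1000
print(bias)  # 2034
```

Under these assumptions, the accessibility sub-model leaves exactly 1074 positions, and the 75-wide profile convolution reduces these to the 1000 bp profile. The shallower bias sub-model leaves more positions than the check requires.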

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Chromatin Accessibility Profile Prediction

You can use this model directly to predict base-resolution chromatin accessibility of a DNA sequence:

Python
>>> from multimolecule import DnaTokenizer, ChromBPNetForProfilePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/chrombpnet")
>>> model = ChromBPNetForProfilePrediction.from_pretrained("multimolecule/chrombpnet")
>>> output = model(**tokenizer(("ACGT" * 529)[:2114], return_tensors="pt"))

>>> output.keys()
odict_keys(['profile_logits', 'count_logits'])

>>> output["profile_logits"].shape
torch.Size([1, 1000, 1])

>>> output["count_logits"].shape
torch.Size([1, 1])

>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 1000, 1])

The recombined track is the usable, bias-corrected base-resolution accessibility prediction.

Interface

  • Input length: 2114 bp DNA window
  • Profile length: 1000 bp
  • Output: factorized (profile_logits, count_logits); recombine the bias-corrected base-resolution track via ChromBPNetForProfilePrediction.postprocess
  • Composition: profile logits added across bias + accessibility sub-models; counts combined via logsumexp
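The model expects a fixed 2114 bp window and reports a 1000 bp profile centered in it. A caller therefore usually crops or pads its sequence and maps profile positions back to input coordinates. The following minimal sketch assumes that padding short sequences with N is acceptable. `center_window` and `PROFILE_OFFSET` are hypothetical names, not multimolecule APIs.

```python
# Hypothetical helper (not part of multimolecule). Padding with "N" is an
# assumption; N is part of the streamline DNA alphabet (ACGTN).

SEQUENCE_LENGTH = 2114
PROFILE_LENGTH = 1000


def center_window(sequence: str, length: int = SEQUENCE_LENGTH, pad: str = "N") -> str:
    """Center-crop or symmetrically pad `sequence` to exactly `length` characters."""
    if len(sequence) >= length:
        start = (len(sequence) - length) // 2
        return sequence[start : start + length]
    missing = length - len(sequence)
    left = missing // 2
    return pad * left + sequence + pad * (missing - left)


# The 1000 bp profile is centered in the 2114 bp input, so it starts at this offset.
PROFILE_OFFSET = (SEQUENCE_LENGTH - PROFILE_LENGTH) // 2  # 557

window = center_window("ACGT" * 600)
print(len(window), PROFILE_OFFSET)  # 2114 557
```

Under this layout, profile position j corresponds to input position j + 557.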

Training Details

ChromBPNet was trained to predict base-resolution chromatin accessibility profiles from ATAC-seq / DNase-seq with explicit enzyme-bias correction.

Training Data

The default ChromBPNet variant is the HEK293T GFP-control model from the RoboATAC ChromBPNet Models release (an automated ATAC-seq dataset from the Kundaje/Greenleaf labs). The accessibility and scaled-bias sub-models are composed internally.

Training Procedure

Pre-training

The model was trained with a composite loss: a multinomial negative log-likelihood on the per-position profile shape plus a mean-squared-error regression on the log total counts.

  • Optimizer: Adam
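The following PyTorch sketch illustrates the composite objective. It is not the code of MultiMolecule's profile/count head. It makes three assumptions:

- count_logits predict log(1 + total counts), the inverse of the expm1 used by postprocess;
- the NLL drops the constant multinomial-coefficient term;
- the name count_loss_weight is borrowed from ChromBPNetConfig.

```python
# Illustrative sketch, not the ChromBPNetProfileCountHead implementation.
# Assumed shapes:
#   profile_logits: (batch, profile_length, num_labels)
#   count_logits:   (batch, num_labels), predicted log(1 + total counts)
#   target_counts:  (batch, profile_length, num_labels), observed per-base counts
import torch
import torch.nn.functional as F


def composite_loss(
    profile_logits: torch.Tensor,
    count_logits: torch.Tensor,
    target_counts: torch.Tensor,
    count_loss_weight: float = 1.0,
) -> torch.Tensor:
    # Multinomial NLL over positions, up to the log multinomial coefficient,
    # which depends only on the observed counts and not on the logits.
    log_probs = F.log_softmax(profile_logits, dim=1)
    profile_nll = -(target_counts * log_probs).sum(dim=1).mean()
    # Mean-squared error on log(1 + total counts) per task/strand.
    log_totals = torch.log1p(target_counts.sum(dim=1))
    count_mse = F.mse_loss(count_logits, log_totals)
    return profile_nll + count_loss_weight * count_mse


torch.manual_seed(0)
targets = torch.poisson(torch.full((2, 1000, 1), 0.5))  # stand-in observed counts
loss = composite_loss(torch.randn(2, 1000, 1), torch.randn(2, 1), targets)
print(loss.item())
```

Dropping the multinomial coefficient does not change the gradients with respect to the logits, because that term depends only on the observed counts.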

Citation

BibTeX
@article{pampari2024chrombpnet,
  author    = {Pampari, Anusri and Shcherbina, Anna and Kvon, Evgeny and Kosicki, Michael and Nair, Surag and Kundu, Soumya and Kathiria, Arwa S. and Risca, Viviana I. and Simola, Kristiina and Funk, Melissa J. and Furlong, Eileen E. M. and Pennacchio, Len A. and Greenleaf, William J. and Kundaje, Anshul},
  title     = {ChromBPNet: bias factorized, base-resolution deep learning models of chromatin accessibility reveal cis-regulatory sequence syntax, transcription factor footprints and regulatory variants},
  journal   = {bioRxiv},
  year      = 2024,
  publisher = {Cold Spring Harbor Laboratory},
  doi       = {10.1101/2024.12.25.630221},
  note      = {Preprint}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the ChromBPNet paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.chrombpnet

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| alphabet | Alphabet \| str \| List[str] \| None | Alphabet to use for tokenization. If None, the standard DNA alphabet is used. If a string, it should name a predefined alphabet: standard, iupac, streamline, or nucleobase. If an alphabet or a list of characters, that specific alphabet is used. | None |
| nmers | int | Size of k-mer to tokenize. | 1 |
| codon | bool | Whether to tokenize into codons. | False |
| replace_U_with_T | bool | Whether to replace U with T. | True |
| do_upper_case | bool | Whether to convert input to uppercase. | True |

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet is used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet is used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
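The doctests above return seven IDs for tataaagta with nmers=3 but only three with codon=True. The difference comes from overlapping versus non-overlapping windows, which the pure-Python sketch below reproduces. `split_sequence` is a hypothetical helper: it mirrors the splitting in `_tokenize` but skips the vocabulary lookup and the <cls>/<eos> tokens.

```python
# Hypothetical helper mirroring the text splitting in DnaTokenizer._tokenize;
# it does not map tokens to IDs or add special tokens.


def split_sequence(text: str, nmers: int = 1, codon: bool = False, replace_U_with_T: bool = True) -> list[str]:
    text = text.upper()
    if replace_U_with_T:
        text = text.replace("U", "T")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length must be a multiple of 3 for codons, got {len(text)}")
        # non-overlapping windows of 3
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # overlapping windows sliding by 1
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(split_sequence("tataaagta", nmers=3))
# ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_sequence("tataaagta", codon=True))
# ['TAT', 'AAA', 'GTA']
```

Both splits share TAT, AAA, and GTA, which is consistent with both doctests mapping them to 84, 6, and 71.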

ChromBPNetConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a ChromBPNetModel. It is used to instantiate a ChromBPNet model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ChromBPNet HEK293T-GFP architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

ChromBPNet predicts base-resolution chromatin accessibility (ATAC-seq / DNase-seq) with explicit enzyme-bias correction. It internally composes two BPNet-style dilated-convolution sub-models:

  • a bias sub-model that captures the Tn5/DNase enzyme cleavage bias on chromatin background;
  • an accessibility sub-model that learns the bias-corrected accessibility signal.

The final prediction is a single base-resolution task whose output is factorized into two terminal branches that share their respective backbones:

  • a profile branch producing per-position multinomial logits of shape (batch_size, profile_length, num_tasks * num_strands);
  • a count branch producing a scalar per task and strand of shape (batch_size, num_tasks * num_strands).

The bias and accessibility sub-models are composed internally: their profile logits are added, and their count logits are combined in log/exp space (logsumexp). They are not a user-facing split.
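The composition rule can be sketched with random stand-in tensors in place of the two sub-models' outputs. The shapes assume the default profile_length=1000 and num_labels=1. This sketch illustrates the arithmetic only and does not call ChromBPNetModel.

```python
# Illustrative arithmetic with random stand-ins for the two sub-models' outputs;
# it does not call ChromBPNetModel. Shapes assume the default configuration.
import torch

torch.manual_seed(0)
batch_size, profile_length, num_labels = 2, 1000, 1
accessibility_profile = torch.randn(batch_size, profile_length, num_labels)
bias_profile = torch.randn(batch_size, profile_length, num_labels)
accessibility_count = torch.randn(batch_size, num_labels)
bias_count = torch.randn(batch_size, num_labels)

# Profile logits are added element-wise per position.
profile_logits = accessibility_profile + bias_profile

# Count logits live in log space; logsumexp over the two sub-models computes
# log(exp(accessibility_count) + exp(bias_count)).
count_logits = torch.logsumexp(torch.stack([accessibility_count, bias_count], dim=-1), dim=-1)

print(profile_logits.shape, count_logits.shape)  # torch.Size([2, 1000, 1]) torch.Size([2, 1])
```

Adding the logits makes the composed profile the renormalized product of the two sub-models' per-position distributions. Reading the count logits as log totals, logsumexp makes the composed total the sum of the two totals in linear space.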

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Vocabulary size of the ChromBPNet model. Defines the number of one-hot input channels derived from input_ids. Defaults to 5 to match the MultiMolecule streamline DNA alphabet (ACGTN). | 5 |
| sequence_length | int | The canonical input DNA sequence length in base pairs. | 2114 |
| profile_length | int | The centered output profile length in base pairs. | 1000 |
| hidden_size | int | Number of channels in the convolutional backbone of the accessibility sub-model. | 512 |
| bias_hidden_size | int | Number of channels in the convolutional backbone of the bias sub-model. | 128 |
| stem_kernel_size | int | Kernel size of the first (motif) convolution. | 21 |
| num_dilated_layers | int | Number of dilated residual convolution blocks following the stem in the accessibility sub-model. | 8 |
| bias_num_dilated_layers | int | Number of dilated residual convolution blocks following the stem in the bias sub-model. | 4 |
| dilated_kernel_size | int | Kernel size of each dilated residual convolution. | 3 |
| profile_kernel_size | int | Kernel size of the wide convolution in the profile branch. | 75 |
| num_tasks | int | Number of prediction tasks. | 1 |
| num_strands | int | Number of strands predicted per task. ChromBPNet ATAC/DNase predicts a single (unstranded) track. | 1 |
| hidden_act | str | The non-linear activation function (function or string) in the backbones. | 'relu' |
| count_loss_weight | float | The weight applied to the count regression loss when combining it with the profile multinomial loss. | 1.0 |
| head | HeadConfig \| None | The configuration of the generic token prediction head. If not provided, it defaults to regression. | None |
| output_hidden_states | bool | Whether to output the backbone hidden states. | False |

Examples:

Python Console Session
>>> from multimolecule import ChromBPNetConfig, ChromBPNetModel
>>> # Initializing a ChromBPNet multimolecule/chrombpnet style configuration
>>> configuration = ChromBPNetConfig()
>>> # Initializing a model (with random weights) from the multimolecule/chrombpnet style configuration
>>> model = ChromBPNetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/chrombpnet/configuration_chrombpnet.py
Python
class ChromBPNetConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`ChromBPNetModel`][multimolecule.models.ChromBPNetModel]. It is used to instantiate a ChromBPNet model according to
    the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the ChromBPNet
    [HEK293T-GFP](https://zenodo.org/records/16295014) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    ChromBPNet predicts base-resolution chromatin accessibility (ATAC-seq / DNase-seq) with explicit enzyme-bias
    correction. It internally composes two BPNet-style dilated-convolution sub-models:

    - a *bias* sub-model that captures the Tn5/DNase enzyme cleavage bias on chromatin background;
    - an *accessibility* sub-model that learns the bias-corrected accessibility signal.

    The final prediction is a single base-resolution task whose output is factorized into two terminal branches that
    share their respective backbones:

    - a *profile* branch producing per-position multinomial logits of shape
      `(batch_size, profile_length, num_tasks * num_strands)`;
    - a *count* branch producing a scalar per task and strand of shape `(batch_size, num_tasks * num_strands)`.

    The bias and accessibility sub-models are composed *internally*: their profile logits are added, and their count
    logits are combined in log/exp space (`logsumexp`). They are not a user-facing split.

    Args:
        vocab_size:
            Vocabulary size of the ChromBPNet model. Defines the number of one-hot input channels derived from
            `input_ids`. Defaults to 5 to match the MultiMolecule `streamline` DNA alphabet (`ACGTN`).
        sequence_length:
            The canonical input DNA sequence length in base pairs.
            Defaults to 2114.
        profile_length:
            The centered output profile length in base pairs.
            Defaults to 1000.
        hidden_size:
            Number of channels in the convolutional backbone of the accessibility sub-model.
        bias_hidden_size:
            Number of channels in the convolutional backbone of the bias sub-model.
        stem_kernel_size:
            Kernel size of the first (motif) convolution.
        num_dilated_layers:
            Number of dilated residual convolution blocks following the stem in the accessibility sub-model.
        bias_num_dilated_layers:
            Number of dilated residual convolution blocks following the stem in the bias sub-model.
        dilated_kernel_size:
            Kernel size of each dilated residual convolution.
        profile_kernel_size:
            Kernel size of the wide convolution in the profile branch.
        num_tasks:
            Number of prediction tasks.
        num_strands:
            Number of strands predicted per task. ChromBPNet ATAC/DNase predicts a single (unstranded) track.
        hidden_act:
            The non-linear activation function (function or string) in the backbones.
        count_loss_weight:
            The weight applied to the count regression loss when combining it with the profile multinomial loss.
        head:
            The configuration of the generic token prediction head. If not provided, it defaults to regression.
        output_hidden_states:
            Whether to output the backbone hidden states.

    Examples:
        >>> from multimolecule import ChromBPNetConfig, ChromBPNetModel
        >>> # Initializing a ChromBPNet multimolecule/chrombpnet style configuration
        >>> configuration = ChromBPNetConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/chrombpnet style configuration
        >>> model = ChromBPNetModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "chrombpnet"

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 2114,
        profile_length: int = 1000,
        hidden_size: int = 512,
        bias_hidden_size: int = 128,
        stem_kernel_size: int = 21,
        num_dilated_layers: int = 8,
        bias_num_dilated_layers: int = 4,
        dilated_kernel_size: int = 3,
        profile_kernel_size: int = 75,
        num_tasks: int = 1,
        num_strands: int = 1,
        hidden_act: str = "relu",
        count_loss_weight: float = 1.0,
        head: HeadConfig | None = None,
        output_hidden_states: bool = False,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        if num_dilated_layers < 1:
            raise ValueError(f"num_dilated_layers ({num_dilated_layers}) must be at least 1.")
        if bias_num_dilated_layers < 1:
            raise ValueError(f"bias_num_dilated_layers ({bias_num_dilated_layers}) must be at least 1.")
        required_profile_features = profile_length + profile_kernel_size - 1
        for name, layers in (
            ("num_dilated_layers", num_dilated_layers),
            ("bias_num_dilated_layers", bias_num_dilated_layers),
        ):
            receptive_field = (stem_kernel_size - 1) + sum(
                (dilated_kernel_size - 1) * 2 ** (i + 1) for i in range(layers)
            )
            valid_feature_length = sequence_length - receptive_field
            if valid_feature_length < required_profile_features:
                raise ValueError(
                    f"{name} leaves {valid_feature_length} valid positions, but the profile head needs at least "
                    f"{required_profile_features}."
                )
        if sequence_length < required_profile_features:
            raise ValueError(
                "sequence_length must be at least profile_length + profile_kernel_size - 1 "
                f"({required_profile_features}), but got {sequence_length}."
            )
        if profile_length < 1:
            raise ValueError(f"profile_length ({profile_length}) must be at least 1.")
        if num_tasks < 1:
            raise ValueError(f"num_tasks ({num_tasks}) must be at least 1.")
        if num_strands < 1:
            raise ValueError(f"num_strands ({num_strands}) must be at least 1.")
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.profile_length = profile_length
        self.hidden_size = hidden_size
        self.bias_hidden_size = bias_hidden_size
        self.stem_kernel_size = stem_kernel_size
        self.num_dilated_layers = num_dilated_layers
        self.bias_num_dilated_layers = bias_num_dilated_layers
        self.dilated_kernel_size = dilated_kernel_size
        self.profile_kernel_size = profile_kernel_size
        self.num_tasks = num_tasks
        self.num_strands = num_strands
        self.hidden_act = hidden_act
        self.count_loss_weight = count_loss_weight
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_hidden_states = output_hidden_states

    @property
    def num_labels(self) -> int:
        return self.num_tasks * self.num_strands

    @num_labels.setter
    def num_labels(self, value: int) -> None:
        # ``PretrainedConfig.__init__`` assigns ``num_labels``; ChromBPNet derives it from
        # ``num_tasks * num_strands`` so the assignment is intentionally ignored.
        pass

ChromBPNetForProfilePrediction

Bases: ChromBPNetPreTrainedModel

ChromBPNet with the factorized profile/count head for base-resolution chromatin-accessibility prediction.

This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a per-position value for each position of the centered output profile (profile_length positions, not the full input window). The single base-resolution task is factorized into two terminal branches:

  • profile_logits: per-position multinomial logits of shape (batch_size, profile_length, num_labels);
  • count_logits: a scalar per task and strand of shape (batch_size, num_labels),

where num_labels = num_tasks * num_strands. Use postprocess to recombine them into the usable base-resolution track.

The enzyme-bias correction (the internal bias + accessibility composition) is performed inside ChromBPNetModel; the factorized head here mirrors BPNet and operates on the already bias-corrected, composed profile and count logits.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ChromBPNetConfig, ChromBPNetForProfilePrediction, DnaTokenizer
>>> config = ChromBPNetConfig()
>>> model = ChromBPNetForProfilePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/chrombpnet")
>>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
>>> output = model(**input)
>>> output["profile_logits"].shape
torch.Size([1, 1000, 1])
>>> output["count_logits"].shape
torch.Size([1, 1])
>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 1000, 1])
Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
class ChromBPNetForProfilePrediction(ChromBPNetPreTrainedModel):
    """
    ChromBPNet with the factorized profile/count head for base-resolution chromatin-accessibility prediction.

    This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a
    per-position value for each position of the centered output profile. The single base-resolution task is
    factorized into two terminal branches:

    - `profile_logits`: per-position multinomial logits of shape `(batch_size, profile_length, num_labels)`;
    - `count_logits`: a scalar per task and strand of shape `(batch_size, num_labels)`,

    where `num_labels = num_tasks * num_strands`. Use
    [`postprocess`][multimolecule.models.ChromBPNetForProfilePrediction.postprocess] to recombine them into the
    usable base-resolution track.

    The enzyme-bias correction (the internal bias + accessibility composition) is performed inside
    [`ChromBPNetModel`][multimolecule.models.ChromBPNetModel]; the factorized head here mirrors BPNet and operates on
    the already bias-corrected, composed profile and count logits.

    Examples:
        >>> import torch
        >>> from multimolecule import ChromBPNetConfig, ChromBPNetForProfilePrediction, DnaTokenizer
        >>> config = ChromBPNetConfig()
        >>> model = ChromBPNetForProfilePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/chrombpnet")
        >>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
        >>> output = model(**input)
        >>> output["profile_logits"].shape
        torch.Size([1, 1000, 1])
        >>> output["count_logits"].shape
        torch.Size([1, 1])
        >>> track = model.postprocess(output)
        >>> track.shape
        torch.Size([1, 1000, 1])
    """

    def __init__(self, config: ChromBPNetConfig):
        super().__init__(config)
        self.model = ChromBPNetModel(config)
        self.profile_count_head = ChromBPNetProfileCountHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_tasks == 1 and self.config.num_strands == 1:
            return ["signal"]
        tasks = [f"task_{index}" for index in range(self.config.num_tasks)]
        if self.config.num_strands == 2:
            strands = ["plus", "minus"]
        else:
            strands = [f"strand_{index}" for index in range(self.config.num_strands)]
        return [f"{task}_{strand}" for task in tasks for strand in strands]

    @merge_with_config_defaults
    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: dict[str, Tensor] | Tuple[Tensor, Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ChromBPNetProfilePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_output = self.profile_count_head(outputs.profile_logits, outputs.count_logits, labels)

        return ChromBPNetProfilePredictorOutput(
            loss=head_output.loss,
            profile_logits=head_output.profile_logits,
            count_logits=head_output.count_logits,
            hidden_states=outputs.hidden_states,
        )

    def postprocess(self, outputs: ChromBPNetProfilePredictorOutput | ModelOutput) -> Tensor:
        r"""
        Recombine the factorized profile and count branches into the usable base-resolution track.

        ChromBPNet does not predict the accessibility track directly; the profile branch predicts the *shape* (a
        per-position multinomial distribution) and the count branch predicts the *total magnitude* (in log space).
        The usable prediction recombines them as `softmax(profile_logits, positions) * expm1(count_logits)`.

        Args:
            outputs: The output of
                [`ChromBPNetForProfilePrediction`][multimolecule.models.ChromBPNetForProfilePrediction].

        Returns:
            The predicted base-resolution track of shape `(batch_size, profile_length, num_labels)`.
        """
        profile_logits = outputs["profile_logits"]
        count_logits = outputs["count_logits"]
        profile = F.softmax(profile_logits, dim=1)
        return profile * torch.expm1(count_logits).unsqueeze(1)

postprocess

Python
postprocess(
    outputs: ChromBPNetProfilePredictorOutput | ModelOutput,
) -> Tensor

Recombine the factorized profile and count branches into the usable base-resolution track.

ChromBPNet does not predict the accessibility track directly; the profile branch predicts the shape (a per-position multinomial distribution) and the count branch predicts the total magnitude (in log space). The usable prediction recombines them as softmax(profile_logits, positions) * expm1(count_logits).
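The recombination formula can be checked numerically. The softmax makes the profile sum to one over positions, so summing the recombined track over positions recovers expm1(count_logits). This minimal sketch uses random stand-in logits and repeats the two operations of postprocess rather than calling the model.

```python
# Illustrative check with random stand-in logits; it repeats the two operations of
# ChromBPNetForProfilePrediction.postprocess instead of calling the model.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
profile_logits = torch.randn(1, 1000, 1)  # (batch_size, profile_length, num_labels)
count_logits = torch.tensor([[3.0]])      # (batch_size, num_labels), log(1 + total)

profile = F.softmax(profile_logits, dim=1)                # sums to 1 over positions
track = profile * torch.expm1(count_logits).unsqueeze(1)  # scale the shape by the total

# Summing over positions recovers expm1(3.0), about 19.09 predicted counts.
print(track.shape, track.sum(dim=1))
```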

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| outputs | ChromBPNetProfilePredictorOutput \| ModelOutput | The output of ChromBPNetForProfilePrediction. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The predicted base-resolution track of shape (batch_size, profile_length, num_labels). |

Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
def postprocess(self, outputs: ChromBPNetProfilePredictorOutput | ModelOutput) -> Tensor:
    r"""
    Recombine the factorized profile and count branches into the usable base-resolution track.

    ChromBPNet does not predict the accessibility track directly; the profile branch predicts the *shape* (a
    per-position multinomial distribution) and the count branch predicts the *total magnitude* (in log space).
    The usable prediction recombines them as `softmax(profile_logits, positions) * expm1(count_logits)`.

    Args:
        outputs: The output of
            [`ChromBPNetForProfilePrediction`][multimolecule.models.ChromBPNetForProfilePrediction].

    Returns:
        The predicted base-resolution track of shape `(batch_size, profile_length, num_labels)`.
    """
    profile_logits = outputs["profile_logits"]
    count_logits = outputs["count_logits"]
    profile = F.softmax(profile_logits, dim=1)
    return profile * torch.expm1(count_logits).unsqueeze(1)

ChromBPNetForTokenPrediction

Bases: ChromBPNetPreTrainedModel

ChromBPNet accessibility backbone with a randomly initialized generic token-prediction head.

This class attaches the shared MultiMolecule token head to the accessibility sub-model representation and returns a standard output with a single logits tensor for downstream fine-tuning. The published ChromBPNet profile/count task remains exposed through ChromBPNetForProfilePrediction.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ChromBPNetConfig, ChromBPNetForTokenPrediction
>>> config = ChromBPNetConfig()
>>> model = ChromBPNetForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, config.sequence_length))
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 2114, 1])
Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
class ChromBPNetForTokenPrediction(ChromBPNetPreTrainedModel):
    """
    ChromBPNet accessibility backbone with a randomly initialized generic token-prediction head.

    This class attaches the shared MultiMolecule token head to the accessibility sub-model representation and returns a
    standard single-`logits` output for downstream fine-tuning. The published ChromBPNet profile/count task remains
    exposed through [`ChromBPNetForProfilePrediction`][multimolecule.models.ChromBPNetForProfilePrediction].

    Examples:
        >>> import torch
        >>> from multimolecule import ChromBPNetConfig, ChromBPNetForTokenPrediction
        >>> config = ChromBPNetConfig()
        >>> model = ChromBPNetForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, config.sequence_length))
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 2114, 1])
    """

    def __init__(self, config: ChromBPNetConfig):
        super().__init__(config)
        self.model = ChromBPNetModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | TokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)

        outputs = self.model(
            input_ids,
            attention_mask=head_attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        return TokenPredictorOutput(
            loss=output.loss,
            logits=output.logits,
            hidden_states=outputs.hidden_states,
        )

ChromBPNetHeadOutput dataclass

Bases: ModelOutput

Output of the factorized ChromBPNet profile/count head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `profile_logits` | `torch.FloatTensor` of shape `(batch_size, profile_length, num_labels)` | Per-position multinomial logits. | `None` |
| `count_logits` | `torch.FloatTensor` of shape `(batch_size, num_labels)` | Per task/strand log-count scalars. | `None` |
| `loss` | `torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided | Composite multinomial-NLL + weighted count-MSE loss. | `None` |

Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
@dataclass
class ChromBPNetHeadOutput(ModelOutput):
    """
    Output of the factorized ChromBPNet profile/count head.

    Args:
        profile_logits (`torch.FloatTensor` of shape `(batch_size, profile_length, num_labels)`):
            Per-position multinomial logits.
        count_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Per task/strand log-count scalars.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL + weighted count-MSE loss.
    """

    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    loss: torch.FloatTensor | None = None
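
The composite loss described above can be sketched for a single sequence and track as follows. This is an illustration under stated assumptions, not the MultiMolecule implementation. The `count_weight` value, the `log1p` count target, and how terms are reduced over batches and tasks are all assumptions, and the helper names are hypothetical.

```python
import math


def multinomial_nll(profile_logits, observed_counts):
    """Profile term: NLL of per-position counts under softmax(profile_logits).

    The multinomial coefficient is dropped because it does not depend on the logits.
    """
    peak = max(profile_logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in profile_logits))
    return -sum(c * (x - log_norm) for c, x in zip(observed_counts, profile_logits))


def composite_loss(profile_logits, count_logit, observed_counts, count_weight):
    """Profile NLL plus a weighted squared error on the log-count (one track, one sequence)."""
    profile_loss = multinomial_nll(profile_logits, observed_counts)
    # Assumption: the count target is log(1 + total observed counts), matching the expm1 in postprocess.
    count_target = math.log1p(sum(observed_counts))
    count_loss = (count_logit - count_target) ** 2
    return profile_loss + count_weight * count_loss


# Uniform logits over two positions and a perfect count prediction:
# profile loss = 2 * log(2), count loss = 0.
print(composite_loss([0.0, 0.0], math.log1p(2.0), [1, 1], count_weight=10.0))  # about 1.3863
```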

ChromBPNetModel

Bases: ChromBPNetPreTrainedModel

The bare ChromBPNet model: an enzyme-bias sub-model composed with a bias-corrected accessibility sub-model.

ChromBPNet predicts base-resolution chromatin accessibility (ATAC-seq / DNase-seq) with explicit enzyme-bias correction. It internally owns two BPNet-style dilated-convolution sub-models and composes them so the model exposes a single clean factorized profile/count output:

  • the bias sub-model captures Tn5/DNase I enzyme cleavage bias, learned from chromatin background (non-peak) regions;
  • the accessibility sub-model learns the bias-corrected accessibility signal.

The two sub-models are composed internally: their per-position profile logits are added, and their count logits are combined via logsumexp, i.e. log(exp(a) + exp(b)), so the exponentiated counts of the two branches add. The sub-model split is an implementation detail, not a user-facing API.
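
Here is a minimal plain-Python sketch of this composition for one track. The helper names are hypothetical. Adding profile logits means `softmax(a + b)` is proportional to `softmax(a) * softmax(b)`, so the bias profile reweights the accessibility profile position by position. Meanwhile, `logsumexp` makes `exp(count_logits)` equal the sum of the two branches' exponentiated count logits.

```python
import math


def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    norm = sum(exps)
    return [e / norm for e in exps]


def compose(accessibility_profile, bias_profile, accessibility_count, bias_count):
    """Combine the two sub-model outputs for one track, as described above."""
    profile_logits = [a + b for a, b in zip(accessibility_profile, bias_profile)]
    count_logit = logsumexp([accessibility_count, bias_count])
    return profile_logits, count_logit


acc_profile, bias_profile = [0.0, 2.0, 0.0], [1.0, 0.0, 1.0]
profile_logits, count_logit = compose(acc_profile, bias_profile, math.log(30.0), math.log(10.0))
print(round(math.exp(count_logit), 6))  # 40.0: the exponentiated counts add
```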

Examples:

Python Console Session
>>> from multimolecule import ChromBPNetConfig, ChromBPNetModel, DnaTokenizer
>>> config = ChromBPNetConfig()
>>> model = ChromBPNetModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/chrombpnet")
>>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
>>> output = model(**input)
>>> output["profile_logits"].shape
torch.Size([1, 1000, 1])
>>> output["count_logits"].shape
torch.Size([1, 1])
Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
class ChromBPNetModel(ChromBPNetPreTrainedModel):
    """
    The bare ChromBPNet model: an enzyme-bias sub-model composed with a bias-corrected accessibility sub-model.

    ChromBPNet predicts base-resolution chromatin accessibility (ATAC-seq / DNase-seq) with explicit enzyme-bias
    correction. It internally owns two BPNet-style dilated-convolution sub-models and composes them so the model
    exposes a single clean factorized profile/count output:

    - the *bias* sub-model captures Tn5/DNase I enzyme cleavage bias, learned from chromatin background (non-peak) regions;
    - the *accessibility* sub-model learns the bias-corrected accessibility signal.

    The two sub-models are composed internally: their per-position profile logits are added, and their count logits are
    combined via `logsumexp`, i.e. `log(exp(a) + exp(b))`, so the exponentiated counts of the two branches add. The
    sub-model split is an implementation detail, not a user-facing API.

    Examples:
        >>> from multimolecule import ChromBPNetConfig, ChromBPNetModel, DnaTokenizer
        >>> config = ChromBPNetConfig()
        >>> model = ChromBPNetModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/chrombpnet")
        >>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
        >>> output = model(**input)
        >>> output["profile_logits"].shape
        torch.Size([1, 1000, 1])
        >>> output["count_logits"].shape
        torch.Size([1, 1])
    """

    def __init__(self, config: ChromBPNetConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embeddings = ChromBPNetEmbedding(config)
        self.accessibility = ChromBPNetBranch(
            config, hidden_size=config.hidden_size, num_dilated_layers=config.num_dilated_layers
        )
        self.bias = ChromBPNetBranch(
            config, hidden_size=config.bias_hidden_size, num_dilated_layers=config.bias_num_dilated_layers
        )
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ChromBPNetModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        output_hidden_states = kwargs.get("output_hidden_states", self.config.output_hidden_states)

        accessibility_output = self.accessibility(embedding_output, output_hidden_states=output_hidden_states)
        bias_output = self.bias(embedding_output, output_hidden_states=output_hidden_states)

        # ChromBPNet composition (kundajelab/chrombpnet `chrombpnet_with_bias_model`):
        # profile logits are added; counts are combined in log/exp space via logsumexp.
        profile_logits = accessibility_output.profile_logits + bias_output.profile_logits
        count_logits = torch.logsumexp(
            torch.stack([accessibility_output.count_logits, bias_output.count_logits], dim=-1), dim=-1
        )

        hidden_states = None
        if output_hidden_states:
            hidden_states = (accessibility_output.hidden_states or ()) + (bias_output.hidden_states or ())

        return ChromBPNetModelOutput(
            last_hidden_state=accessibility_output.last_hidden_state,
            profile_logits=profile_logits,
            count_logits=count_logits,
            hidden_states=hidden_states,
        )

ChromBPNetModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the ChromBPNet backbone.

The ChromBPNet backbone performs the bias + accessibility composition and exposes both the accessibility branch representation for generic fine-tuning and the composed factorized profile / count logits.

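Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `last_hidden_state` | `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)` | Accessibility branch backbone representation. | `None` |
| `profile_logits` | `torch.FloatTensor` of shape `(batch_size, profile_length, num_labels)` | Composed (bias-corrected) per-position multinomial logits. | `None` |
| `count_logits` | `torch.FloatTensor` of shape `(batch_size, num_labels)` | Composed per task/strand log-count scalars. | `None` |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Backbone hidden states, accessibility sub-model first, then bias sub-model. | `None` |
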
Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
@dataclass
class ChromBPNetModelOutput(ModelOutput):
    """
    Base class for outputs of the ChromBPNet backbone.

    The ChromBPNet backbone performs the bias + accessibility composition and exposes both the accessibility branch
    representation for generic fine-tuning and the composed factorized profile / count logits.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Accessibility branch backbone representation.
        profile_logits (`torch.FloatTensor` of shape `(batch_size, profile_length, num_labels)`):
            Composed (bias-corrected) per-position multinomial logits.
        count_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Composed per task/strand log-count scalars.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` backbone hidden states, accessibility sub-model first, then bias sub-model.
            Accessibility states have shape `(batch_size, sequence_length, hidden_size)`; bias states have
            `bias_hidden_size` channels instead of `hidden_size`.
    """

    last_hidden_state: torch.FloatTensor | None = None
    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None

ChromBPNetPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
class ChromBPNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ChromBPNetConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["ChromBPNetLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
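
The initializer above reduces to simple closed-form scales, computed here in plain Python. For `Conv1d`, `kaiming_normal_` with `mode="fan_out"` and the ReLU gain `sqrt(2)` samples from a normal distribution with standard deviation `sqrt(2 / (out_channels * kernel_size))`. For `nn.Linear`, both the `kaiming_uniform_` weight bound with `a=sqrt(5)` and the bias bound reduce to `1 / sqrt(fan_in)`, which is the same scheme as PyTorch's default `nn.Linear` initialization. The 512-filter, kernel-size-3 convolution is a hypothetical example, not a claim about ChromBPNet's layer shapes.

```python
import math


def kaiming_normal_std_fan_out(out_channels, kernel_size):
    # Conv1d weight (out, in, k): fan_out = out * k; ReLU gain = sqrt(2); std = gain / sqrt(fan_out).
    return math.sqrt(2.0) / math.sqrt(out_channels * kernel_size)


def linear_bound(fan_in):
    # kaiming_uniform_ with a = sqrt(5): gain = sqrt(2 / (1 + 5)), bound = gain * sqrt(3 / fan_in) = 1 / sqrt(fan_in).
    gain = math.sqrt(2.0 / (1.0 + 5.0))
    return gain * math.sqrt(3.0 / fan_in)


std = kaiming_normal_std_fan_out(512, 3)  # hypothetical conv: 512 filters, kernel size 3
print(round(std, 4))  # 0.0361
print(round(linear_bound(512), 4))  # 0.0442, same as 1 / sqrt(512)
```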

ChromBPNetProfilePredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of ChromBPNetForProfilePrediction.

The standard single-logits predictor dataclasses cannot express ChromBPNet's factorized output, so this model-local dataclass exposes the two terminal branches separately. Use ChromBPNetForProfilePrediction.postprocess to recombine them.

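Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided | Composite multinomial-NLL (profile) + weighted count-MSE (count) loss. | `None` |
| `profile_logits` | `torch.FloatTensor` of shape `(batch_size, profile_length, num_labels)` | Per-position multinomial logits, where `num_labels = num_tasks * num_strands`. | `None` |
| `count_logits` | `torch.FloatTensor` of shape `(batch_size, num_labels)` | Per task/strand log-count scalars. | `None` |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of backbone hidden states. | `None` |
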
Source code in multimolecule/models/chrombpnet/modeling_chrombpnet.py
Python
@dataclass
class ChromBPNetProfilePredictorOutput(ModelOutput):
    """
    Base class for outputs of
    [`ChromBPNetForProfilePrediction`][multimolecule.models.ChromBPNetForProfilePrediction].

    The standard single-`logits` predictor dataclasses cannot express ChromBPNet's factorized output, so this
    model-local dataclass exposes the two terminal branches separately. Use
    [`postprocess`][multimolecule.models.ChromBPNetForProfilePrediction.postprocess] to recombine them.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL (profile) + weighted count-MSE (count) loss.
        profile_logits (`torch.FloatTensor` of shape `(batch_size, profile_length, num_labels)`):
            Per-position multinomial logits, where `num_labels = num_tasks * num_strands`.
        count_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Per task/strand log-count scalars.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`.
    """

    loss: torch.FloatTensor | None = None
    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None