BPNet

Base-resolution convolutional neural network for predicting transcription-factor binding profiles from DNA sequence.

Disclaimer

This is an UNOFFICIAL implementation of Base-resolution models of transcription-factor binding reveal soft motif syntax by Žiga Avsec, Melanie Weilert et al.

The OFFICIAL repository of BPNet is at kundajelab/bpnet.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing BPNet did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

BPNet is a convolutional neural network (CNN) trained to predict base-resolution transcription-factor binding signal (ChIP-nexus) from primary DNA sequence. It uses a convolutional motif stem followed by a stack of dilated residual convolutions that aggregate ~1 kb of genomic context. The output is factorized into profile and count branches, and the usable base-resolution prediction is reconstructed by BPNetForProfilePrediction.postprocess. Please refer to the Training Details section for more information on the training process.
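The PyTorch sketch below illustrates that data flow (one-hot input, motif stem, residual dilated stack, and the two output branches) using the BPNetConfig defaults. It is a hypothetical re-sketch and not the MultiMolecule source. The class name BPNetSketch, the doubling dilation schedule (2, 4, ..., 512), "same" padding, and a count branch made of one linear layer on position-averaged features are all assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class BPNetSketch(nn.Module):
    """Hypothetical re-sketch of the BPNet layout; not the MultiMolecule implementation."""

    def __init__(self, vocab_size=5, hidden_size=64, stem_kernel_size=25, num_dilated_layers=9,
                 dilated_kernel_size=3, profile_kernel_size=25, num_tasks=4, num_strands=2):
        super().__init__()
        self.vocab_size = vocab_size
        num_labels = num_tasks * num_strands
        # Motif stem: one wide convolution over the one-hot sequence.
        self.stem = nn.Conv1d(vocab_size, hidden_size, stem_kernel_size, padding="same")
        # Dilated residual stack; the doubling dilation (2, 4, ..., 512) is an assumption.
        self.dilated = nn.ModuleList([
            nn.Conv1d(hidden_size, hidden_size, dilated_kernel_size, dilation=2 ** (i + 1), padding="same")
            for i in range(num_dilated_layers)
        ])
        # Profile branch: stride-1 transposed convolution; padding k // 2 keeps the length unchanged.
        self.profile = nn.ConvTranspose1d(hidden_size, num_labels, profile_kernel_size,
                                          padding=profile_kernel_size // 2)
        # Count branch (assumed): one scalar per task/strand from position-averaged features.
        self.count = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor):
        # (batch, length) token ids -> (batch, vocab_size, length) one-hot channels
        x = F.one_hot(input_ids, self.vocab_size).float().transpose(1, 2)
        x = F.relu(self.stem(x))
        for conv in self.dilated:
            x = x + F.relu(conv(x))  # residual connection around each dilated convolution
        profile_logits = self.profile(x).transpose(1, 2)  # (batch, length, num_labels)
        count_logits = self.count(x.mean(dim=2))  # (batch, num_labels)
        return profile_logits, count_logits


model = BPNetSketch()
profile_logits, count_logits = model(torch.randint(5, (2, 10)))
print(profile_logits.shape, count_logits.shape)  # torch.Size([2, 10, 8]) torch.Size([2, 8])
```

The output shapes line up with those in the usage example below: per-position logits for every input base, plus one count per task and strand.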

Model Specification

| Num Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) |
| ---------- | ----------- | ------------------ | --------- | -------- |
| 10         | 64          | 0.13               | 0.24      | 0.12     |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Transcription-Factor Binding Profile Prediction

You can use this model directly to predict transcription-factor binding profiles of a DNA sequence:

Python
>>> from multimolecule import DnaTokenizer, BPNetForProfilePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
>>> model = BPNetForProfilePrediction.from_pretrained("multimolecule/bpnet")
>>> output = model(**tokenizer("ACGTNACGTN", return_tensors="pt"))

>>> output.keys()
odict_keys(['profile_logits', 'count_logits'])

>>> output["profile_logits"].shape
torch.Size([1, 10, 8])

>>> output["count_logits"].shape
torch.Size([1, 8])

>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 10, 8])

The recombined track is the usable base-resolution prediction. Its last dimension holds num_tasks × num_strands channels, ordered task by task with the two strands of each factor adjacent: Oct4 (forward/plus, reverse/minus), then Sox2, Nanog and Klf4.
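To pick a single channel out of the track, you need its index. The standalone function below re-implements the naming rule of the output_channels property shown in the BPNetForProfilePrediction source further down: tasks form the outer loop and strands the inner loop. It runs without the library. With a loaded model, model.output_channels should return the same list.

```python
def output_channels(num_tasks: int = 4, num_strands: int = 2) -> list[str]:
    # Mirrors BPNetForProfilePrediction.output_channels: task-major, strand-minor ordering.
    if num_tasks == 4:
        tasks = ["Oct4", "Sox2", "Nanog", "Klf4"]
    else:
        tasks = [f"task_{index}" for index in range(num_tasks)]
    if num_strands == 2:
        strands = ["plus", "minus"]
    else:
        strands = [f"strand_{index}" for index in range(num_strands)]
    return [f"{task}_{strand}" for task in tasks for strand in strands]


channels = output_channels()
print(channels.index("Nanog_minus"))  # 5
```

For example, `track[..., channels.index("Nanog_minus")]` selects the Nanog reverse-strand profile.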

Interface

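  • Input length: 1000 bp DNA window (the training window length; the model also accepts other lengths, as the 10 bp example above shows)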
  • Output: factorized (profile_logits, count_logits); recombine the usable base-resolution track via BPNetForProfilePrediction.postprocess
  • Output shape: (batch_size, profile_length, num_tasks × num_strands); default Oct4 / Sox2 / Nanog / Klf4 × forward / reverse = 8 channels
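The backbone works on one-hot channels derived from the token ids (vocab_size = 5, matching the streamline alphabet ACGTN). The NumPy sketch below shows that layout for a 1000 bp window. The channel order A, C, G, T, N, the choice to give N its own channel, the rule that maps unknown letters to N, and the helper name one_hot_dna are all illustrative assumptions. They are not taken from the MultiMolecule embedding code.

```python
import numpy as np

ALPHABET = "ACGTN"  # assumed channel order of the streamline DNA alphabet


def one_hot_dna(sequence: str) -> np.ndarray:
    """Return a (length, 5) one-hot matrix; letters outside ACGTN are mapped to N (assumption)."""
    index = {base: i for i, base in enumerate(ALPHABET)}
    ids = np.array([index.get(base, index["N"]) for base in sequence.upper()], dtype=np.int64)
    return np.eye(len(ALPHABET), dtype=np.float32)[ids]


window = "ACGT" * 250  # a 1000 bp window
x = one_hot_dna(window)
print(x.shape)  # (1000, 5)
```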

Training Details

BPNet was trained to predict the base-resolution ChIP-nexus binding profiles of the pluripotency transcription factors Oct4, Sox2, Nanog and Klf4 in mouse embryonic stem cells.

Training Data

The published BPNet-OSKN model was trained on ChIP-nexus profiles for Oct4, Sox2, Nanog and Klf4, using 1 kb genomic windows centered on detected binding peaks. The training regions and trained Keras checkpoint are distributed as part of the BPNet manuscript data.
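Cutting a fixed-width window around a peak summit is a small but error-prone step. The helper below is a hypothetical sketch and not part of the BPNet data pipeline. Padding with N when the window runs past either end of the sequence is an assumption; the original pipeline may handle chromosome edges differently.

```python
def center_window(sequence: str, summit: int, width: int = 1000, pad: str = "N") -> str:
    """Return `width` bases centered on `summit`, padding with `pad` past either end (assumption)."""
    if not 0 <= summit < len(sequence):
        raise ValueError(f"summit {summit} lies outside a sequence of length {len(sequence)}")
    start = summit - width // 2
    end = start + width
    left = max(0, -start)  # bases missing before position 0
    right = max(0, end - len(sequence))  # bases missing after the last position
    return pad * left + sequence[max(0, start):min(len(sequence), end)] + pad * right


print(len(center_window("A" * 5000, summit=100)))  # 1000
print(center_window("ACGT" * 10, summit=0, width=10))  # NNNNNACGTA
```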

Training Procedure

Pre-training

The model was trained with a composite loss: a multinomial negative log-likelihood on the per-position profile shape plus a mean-squared-error regression on the log total counts.

  • Optimizer: Adam
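The composite objective fits in a few lines of NumPy. This sketch follows the description above but is not the MultiMolecule head. The log(1 + total) count target, the mean reduction over batch and channels, and the helper names bpnet_loss and log_softmax are assumptions. count_loss_weight plays the role of the BPNetConfig parameter of the same name.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def bpnet_loss(profile_logits, count_logits, observed, count_loss_weight=1.0):
    """profile_logits, observed: (batch, length, channels); count_logits: (batch, channels)."""
    log_probs = log_softmax(profile_logits, axis=1)  # multinomial over positions
    # Multinomial NLL without the log-factorial coefficient, which does not depend on the model.
    profile_nll = -(observed * log_probs).sum(axis=1)  # (batch, channels)
    # Count regression target: log(1 + total reads) per task/strand (assumed pseudocount of 1).
    log_totals = np.log1p(observed.sum(axis=1))
    count_mse = (count_logits - log_totals) ** 2
    return profile_nll.mean() + count_loss_weight * count_mse.mean()


logits = np.zeros((1, 4, 1))
observed = np.ones((1, 4, 1))
print(bpnet_loss(logits, np.log([[5.0]]), observed))  # 4 * log(4) ≈ 5.545
```

In this toy case the count term is zero because the predicted log count equals log(1 + 4), so the loss reduces to the profile term alone.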

Citation

BibTeX
@article{avsec2021baseresolution,
  author    = {Avsec, {\v{Z}}iga and Weilert, Melanie and Shrikumar, Avanti and Krueger, Sabrina and Alexandari, Amr and Dalal, Khyati and Fropf, Robin and McAnany, Charles and Gagneur, Julien and Kundaje, Anshul and Zeitlinger, Julia},
  title     = {Base-resolution models of transcription-factor binding reveal soft motif syntax},
  journal   = {Nature Genetics},
  volume    = 53,
  number    = 3,
  pages     = {354--366},
  year      = 2021,
  publisher = {Nature Publishing Group},
  doi       = {10.1038/s41588-021-00782-6}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the BPNet paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.bpnet

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

alphabet (Alphabet | str | List[str] | None, default: None): Alphabet to use for tokenization.

  • If None, the standard DNA alphabet will be used.
  • If a string, it should correspond to the name of a predefined alphabet. The options include
    • standard
    • iupac
    • streamline
    • nucleobase
  • If an alphabet or a list of characters, that specific alphabet will be used.

nmers (int, default: 1): Size of k-mer to tokenize.

codon (bool, default: False): Whether to tokenize into codons.

replace_U_with_T (bool, default: True): Whether to replace U with T.

do_upper_case (bool, default: True): Whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
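The two multi-character modes in _tokenize split a sequence differently. nmers produces overlapping k-mers with stride 1, while codon produces non-overlapping triplets. The standalone function below, a hypothetical helper named split_sequence, copies those splitting rules with do_upper_case and replace_U_with_T at their defaults. It accounts for the token counts in the examples above: nmers=3 yields seven ids between <cls> and <eos>, while codon=True yields three.

```python
def split_sequence(text: str, nmers: int = 1, codon: bool = False) -> list[str]:
    # Same splitting rules as DnaTokenizer._tokenize, after upper-casing and U -> T.
    text = text.upper().replace("U", "T")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length of input sequence must be a multiple of 3, but got {len(text)}")
        return [text[i : i + 3] for i in range(0, len(text), 3)]  # non-overlapping triplets
    if nmers > 1:
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]  # overlapping, stride 1
    return list(text)


print(split_sequence("tataaagta", nmers=3))  # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_sequence("tataaagta", codon=True))  # ['TAT', 'AAA', 'GTA']
```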

BPNetConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a BPNetModel. It is used to instantiate a BPNet model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the BPNet-OSKN architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

BPNet predicts a single base-resolution signal task whose output is factorized into two terminal branches that share the dilated-convolution backbone:

  • a profile branch producing per-position multinomial logits of shape (batch_size, sequence_length, num_tasks * num_strands);
  • a count branch producing a scalar per task and strand of shape (batch_size, num_tasks * num_strands).

Parameters:

vocab_size (int, default: 5): Vocabulary size of the BPNet model. Defines the number of one-hot input channels derived from input_ids. Defaults to 5 to match the MultiMolecule streamline DNA alphabet (ACGTN).

hidden_size (int, default: 64): Number of channels in the convolutional backbone.

stem_kernel_size (int, default: 25): Kernel size of the first (motif) convolution.

num_dilated_layers (int, default: 9): Number of dilated residual convolution blocks following the stem.

dilated_kernel_size (int, default: 3): Kernel size of each dilated residual convolution.

profile_kernel_size (int, default: 25): Kernel size of the transposed convolution in the profile branch.

num_tasks (int, default: 4): Number of prediction tasks (e.g. transcription factors).

num_strands (int, default: 2): Number of strands predicted per task.

hidden_act (str, default: 'relu'): The non-linear activation function (function or string) in the backbone.

count_loss_weight (float, default: 1.0): The weight applied to the count regression loss when combining it with the profile multinomial loss.

head (HeadConfig | None, default: None): The configuration of the generic token prediction head. If not provided, it defaults to regression.

output_hidden_states (bool, default: False): Whether to output the backbone hidden states.
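As a sanity check on the Model Specification table, the parameter count can be worked out by hand from these defaults. The arithmetic assumes a particular layer layout: a bias on every layer, a transposed convolution for the profile branch, and a single linear layer on pooled features for the count branch. The helper name bpnet_param_count is hypothetical.

```python
def bpnet_param_count(vocab_size=5, hidden_size=64, stem_kernel_size=25, num_dilated_layers=9,
                      dilated_kernel_size=3, profile_kernel_size=25, num_tasks=4, num_strands=2):
    """Parameter count under the assumed layout (weights + biases for every layer)."""
    num_labels = num_tasks * num_strands
    stem = vocab_size * hidden_size * stem_kernel_size + hidden_size
    dilated = num_dilated_layers * (hidden_size * hidden_size * dilated_kernel_size + hidden_size)
    profile = hidden_size * num_labels * profile_kernel_size + num_labels  # transposed conv
    count = hidden_size * num_labels + num_labels  # linear layer on pooled features
    return stem + dilated + profile + count


print(bpnet_param_count())  # 132560
```

The result, 132,560 parameters, is consistent with the 0.13 M in the table. The 10 layers listed there correspond to 1 stem convolution plus num_dilated_layers = 9.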

Examples:

Python Console Session
>>> from multimolecule import BPNetConfig, BPNetModel
>>> # Initializing a BPNet multimolecule/bpnet style configuration
>>> configuration = BPNetConfig()
>>> # Initializing a model (with random weights) from the multimolecule/bpnet style configuration
>>> model = BPNetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/bpnet/configuration_bpnet.py
Python
class BPNetConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BPNetModel`][multimolecule.models.BPNetModel]. It is used to instantiate a BPNet model according to the specified
    arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
    configuration to that of the BPNet [BPNet-OSKN](https://zenodo.org/records/4294904) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    BPNet predicts a single base-resolution signal task whose output is factorized into two terminal branches that
    share the dilated-convolution backbone:

    - a *profile* branch producing per-position multinomial logits of shape
      `(batch_size, sequence_length, num_tasks * num_strands)`;
    - a *count* branch producing a scalar per task and strand of shape `(batch_size, num_tasks * num_strands)`.

    Args:
        vocab_size:
            Vocabulary size of the BPNet model. Defines the number of one-hot input channels derived from `input_ids`.
            Defaults to 5 to match the MultiMolecule `streamline` DNA alphabet (`ACGTN`).
        hidden_size:
            Number of channels in the convolutional backbone.
        stem_kernel_size:
            Kernel size of the first (motif) convolution.
        num_dilated_layers:
            Number of dilated residual convolution blocks following the stem.
        dilated_kernel_size:
            Kernel size of each dilated residual convolution.
        profile_kernel_size:
            Kernel size of the transposed convolution in the profile branch.
        num_tasks:
            Number of prediction tasks (e.g. transcription factors).
        num_strands:
            Number of strands predicted per task.
        hidden_act:
            The non-linear activation function (function or string) in the backbone.
        count_loss_weight:
            The weight applied to the count regression loss when combining it with the profile multinomial loss.
        head:
            The configuration of the generic token prediction head. If not provided, it defaults to regression.
        output_hidden_states:
            Whether to output the backbone hidden states.

    Examples:
        >>> from multimolecule import BPNetConfig, BPNetModel
        >>> # Initializing a BPNet multimolecule/bpnet style configuration
        >>> configuration = BPNetConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/bpnet style configuration
        >>> model = BPNetModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "bpnet"

    def __init__(
        self,
        vocab_size: int = 5,
        hidden_size: int = 64,
        stem_kernel_size: int = 25,
        num_dilated_layers: int = 9,
        dilated_kernel_size: int = 3,
        profile_kernel_size: int = 25,
        num_tasks: int = 4,
        num_strands: int = 2,
        hidden_act: str = "relu",
        count_loss_weight: float = 1.0,
        head: HeadConfig | None = None,
        output_hidden_states: bool = False,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        if num_dilated_layers < 1:
            raise ValueError(f"num_dilated_layers ({num_dilated_layers}) must be at least 1.")
        if num_tasks < 1:
            raise ValueError(f"num_tasks ({num_tasks}) must be at least 1.")
        if num_strands < 1:
            raise ValueError(f"num_strands ({num_strands}) must be at least 1.")
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.stem_kernel_size = stem_kernel_size
        self.num_dilated_layers = num_dilated_layers
        self.dilated_kernel_size = dilated_kernel_size
        self.profile_kernel_size = profile_kernel_size
        self.num_tasks = num_tasks
        self.num_strands = num_strands
        self.hidden_act = hidden_act
        self.count_loss_weight = count_loss_weight
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_hidden_states = output_hidden_states

    @property
    def num_labels(self) -> int:
        return self.num_tasks * self.num_strands

    @num_labels.setter
    def num_labels(self, value: int) -> None:
        # ``PretrainedConfig.__init__`` assigns ``num_labels``; BPNet derives it from
        # ``num_tasks * num_strands`` so the assignment is intentionally ignored.
        pass

BPNetForProfilePrediction

Bases: BPNetPreTrainedModel

BPNet with the factorized profile/count head for base-resolution signal prediction.

This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a per-position value for every input nucleotide. The single base-resolution task is factorized into two terminal branches sharing the backbone:

  • profile_logits: per-position multinomial logits of shape (batch_size, sequence_length, num_labels);
  • count_logits: a scalar per task and strand of shape (batch_size, num_labels),

where num_labels = num_tasks * num_strands. Use postprocess to recombine them into the usable base-resolution track.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import BPNetConfig, BPNetForProfilePrediction, DnaTokenizer
>>> config = BPNetConfig()
>>> model = BPNetForProfilePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
>>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
>>> output = model(**input)
>>> output["profile_logits"].shape
torch.Size([1, 10, 8])
>>> output["count_logits"].shape
torch.Size([1, 8])
>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 10, 8])
Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
class BPNetForProfilePrediction(BPNetPreTrainedModel):
    """
    BPNet with the factorized profile/count head for base-resolution signal prediction.

    This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a
    per-position value for every input nucleotide. The single base-resolution task is factorized into two terminal
    branches sharing the backbone:

    - `profile_logits`: per-position multinomial logits of shape `(batch_size, sequence_length, num_labels)`;
    - `count_logits`: a scalar per task and strand of shape `(batch_size, num_labels)`,

    where `num_labels = num_tasks * num_strands`. Use [`postprocess`][multimolecule.models.BPNetForProfilePrediction.
    postprocess] to recombine them into the usable base-resolution track.

    Examples:
        >>> import torch
        >>> from multimolecule import BPNetConfig, BPNetForProfilePrediction, DnaTokenizer
        >>> config = BPNetConfig()
        >>> model = BPNetForProfilePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
        >>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
        >>> output = model(**input)
        >>> output["profile_logits"].shape
        torch.Size([1, 10, 8])
        >>> output["count_logits"].shape
        torch.Size([1, 8])
        >>> track = model.postprocess(output)
        >>> track.shape
        torch.Size([1, 10, 8])
    """

    def __init__(self, config: BPNetConfig):
        super().__init__(config)
        self.model = BPNetModel(config)
        self.profile_count_head = BPNetProfileCountHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_tasks == 4:
            tasks = ["Oct4", "Sox2", "Nanog", "Klf4"]
        else:
            tasks = [f"task_{index}" for index in range(self.config.num_tasks)]
        if self.config.num_strands == 2:
            strands = ["plus", "minus"]
        else:
            strands = [f"strand_{index}" for index in range(self.config.num_strands)]
        return [f"{task}_{strand}" for task in tasks for strand in strands]

    @merge_with_config_defaults
    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: dict[str, Tensor] | Tuple[Tensor, Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BPNetProfilePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_output = self.profile_count_head(outputs.last_hidden_state, labels)

        return BPNetProfilePredictorOutput(
            loss=head_output.loss,
            profile_logits=head_output.profile_logits,
            count_logits=head_output.count_logits,
            hidden_states=outputs.hidden_states,
        )

    def postprocess(self, outputs: BPNetProfilePredictorOutput | ModelOutput) -> Tensor:
        r"""
        Recombine the factorized profile and count branches into the usable base-resolution track.

        BPNet does not predict the signal track directly; the profile branch predicts the *shape* (a per-position
        multinomial distribution) and the count branch predicts the *total magnitude* (in log space). The usable
        prediction recombines them as `softmax(profile_logits, positions) * exp(count_logits)`.

        Args:
            outputs: The output of [`BPNetForProfilePrediction`][multimolecule.models.BPNetForProfilePrediction].

        Returns:
            The predicted base-resolution track of shape `(batch_size, sequence_length, num_labels)`.
        """
        profile_logits = outputs["profile_logits"]
        count_logits = outputs["count_logits"]
        profile = F.softmax(profile_logits, dim=1)
        return profile * torch.exp(count_logits).unsqueeze(1)

postprocess

Python
postprocess(
    outputs: BPNetProfilePredictorOutput | ModelOutput,
) -> Tensor

Recombine the factorized profile and count branches into the usable base-resolution track.

BPNet does not predict the signal track directly; the profile branch predicts the shape (a per-position multinomial distribution) and the count branch predicts the total magnitude (in log space). The usable prediction recombines them as softmax(profile_logits, positions) * exp(count_logits).
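The softmax runs over positions (dim=1), so each channel of the recombined track sums to exp(count_logits). The count branch sets the total, and the profile branch distributes it along the sequence. The NumPy function below mirrors the two-line method so you can check this property without PyTorch. The name recombine is ours.

```python
import numpy as np


def recombine(profile_logits: np.ndarray, count_logits: np.ndarray) -> np.ndarray:
    """NumPy mirror of softmax(profile_logits, positions) * exp(count_logits)."""
    shifted = profile_logits - profile_logits.max(axis=1, keepdims=True)  # numerical stability
    profile = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)  # sums to 1 over positions
    return profile * np.exp(count_logits)[:, None, :]  # broadcast the totals over positions


rng = np.random.default_rng(0)
track = recombine(rng.normal(size=(1, 10, 8)), np.log(np.full((1, 8), 50.0)))
print(track.sum(axis=1))  # every channel sums to 50 (up to floating-point error)
```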

Parameters:

outputs (BPNetProfilePredictorOutput | ModelOutput, required): The output of BPNetForProfilePrediction.

Returns:

Tensor: The predicted base-resolution track of shape (batch_size, sequence_length, num_labels).


BPNetForTokenPrediction

Bases: BPNetPreTrainedModel

BPNet backbone with a randomly initialized generic token-prediction head.

This class is intended for downstream fine-tuning from the BPNet backbone. It returns the standard TokenPredictorOutput with a single logits field, unlike BPNetForProfilePrediction, which exposes the published factorized profile_logits / count_logits task head.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import BPNetConfig, BPNetForTokenPrediction
>>> config = BPNetConfig()
>>> model = BPNetForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 16))
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 16, 8])
Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
class BPNetForTokenPrediction(BPNetPreTrainedModel):
    """
    BPNet backbone with a randomly initialized generic token-prediction head.

    This class is intended for downstream fine-tuning from the BPNet backbone. It returns the standard
    [`TokenPredictorOutput`][multimolecule.models.TokenPredictorOutput] with a single `logits` field, unlike
    [`BPNetForProfilePrediction`][multimolecule.models.BPNetForProfilePrediction], which exposes the published
    factorized `profile_logits` / `count_logits` task head.

    Examples:
        >>> import torch
        >>> from multimolecule import BPNetConfig, BPNetForTokenPrediction
        >>> config = BPNetConfig()
        >>> model = BPNetForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 16))
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 16, 8])
    """

    def __init__(self, config: BPNetConfig):
        super().__init__(config)
        self.model = BPNetModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | TokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)

        outputs = self.model(
            input_ids,
            attention_mask=head_attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        return TokenPredictorOutput(
            loss=output.loss,
            logits=output.logits,
            hidden_states=outputs.hidden_states,
        )

BPNetHeadOutput dataclass

Bases: ModelOutput

Output of the factorized BPNet profile/count head.

Parameters:

profile_logits (torch.FloatTensor of shape (batch_size, sequence_length, num_labels), default: None): Per-position multinomial logits.

count_logits (torch.FloatTensor of shape (batch_size, num_labels), default: None): Per task/strand log-count scalars.

loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided; default: None): Composite multinomial-NLL + weighted count-MSE loss.
Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
@dataclass
class BPNetHeadOutput(ModelOutput):
    """
    Output of the factorized BPNet profile/count head.

    Args:
        profile_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-position multinomial logits.
        count_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Per task/strand log-count scalars.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL + weighted count-MSE loss.
    """

    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    loss: torch.FloatTensor | None = None

BPNetModel

Bases: BPNetPreTrainedModel

The bare BPNet dilated-convolution backbone producing per-position features.

Examples:

Python Console Session
>>> from multimolecule import BPNetConfig, BPNetModel, DnaTokenizer
>>> config = BPNetConfig()
>>> model = BPNetModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
>>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
>>> output = model(**input)
>>> output["last_hidden_state"].shape
torch.Size([1, 10, 64])
Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
class BPNetModel(BPNetPreTrainedModel):
    """
    The bare BPNet dilated-convolution backbone producing per-position features.

    Examples:
        >>> from multimolecule import BPNetConfig, BPNetModel, DnaTokenizer
        >>> config = BPNetConfig()
        >>> model = BPNetModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
        >>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
        >>> output = model(**input)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 10, 64])
    """

    def __init__(self, config: BPNetConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embeddings = BPNetEmbedding(config)
        self.encoder = BPNetEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BPNetModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_output = self.encoder(embedding_output, **kwargs)
        last_hidden_state = encoder_output.last_hidden_state.transpose(1, 2)

        return BPNetModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_output.hidden_states,
        )
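The argument handling at the top of `BPNetModel.forward` can be traced without loading a model. The sketch below re-implements only that control flow in plain Python. `FakeNested` is a hypothetical stand-in for `NestedTensor` that exposes the same `.tensor` and `.mask` attributes the code reads, and `resolve_inputs` is an illustrative name, not a MultiMolecule function.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeNested:
    """Hypothetical stand-in for NestedTensor: a padded batch plus its padding mask."""

    tensor: Any
    mask: Any


def resolve_inputs(input_ids=None, attention_mask=None, inputs_embeds=None) -> Tuple[Any, Any, Any]:
    """Sketch of the argument checks performed at the start of BPNetModel.forward."""
    # Exactly one of input_ids / inputs_embeds must be provided.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    # A nested input carries its own mask, used only when no explicit mask was passed.
    if isinstance(input_ids, FakeNested):
        if attention_mask is None:
            attention_mask = input_ids.mask
        input_ids = input_ids.tensor
    if isinstance(inputs_embeds, FakeNested):
        if attention_mask is None:
            attention_mask = inputs_embeds.mask
        inputs_embeds = inputs_embeds.tensor
    return input_ids, attention_mask, inputs_embeds


ids, mask, embeds = resolve_inputs(input_ids=FakeNested([[1, 2, 0]], [[1, 1, 0]]))
print(ids, mask, embeds)
```

An explicitly passed `attention_mask` always wins over the mask stored in the nested input, which matches the `if attention_mask is None` guards in the listing.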

BPNetModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the BPNet backbone.

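Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `last_hidden_state` | `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)` | Per-position backbone features. | `None` |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of `torch.FloatTensor` (one for the stem output plus one per dilated layer) of shape `(batch_size, sequence_length, hidden_size)`. | `None` |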
Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
@dataclass
class BPNetModelOutput(ModelOutput):
    """
    Base class for outputs of the BPNet backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Per-position backbone features.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the stem output plus one per dilated layer) of shape `(batch_size,
            sequence_length, hidden_size)`.
    """

    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
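Because `hidden_states` defaults to `None`, it only appears in the output when hidden states are requested; the Transformers `ModelOutput` base class ignores fields that are `None`, which is also why `loss` is absent from `output.keys()` in the usage example at the top of this card. The toy dataclass below mimics that behaviour in plain Python. `ToyModelOutput` is hypothetical and is not the `transformers` implementation.

```python
from collections import OrderedDict
from dataclasses import dataclass, fields
from typing import Optional, Tuple


@dataclass
class ToyModelOutput:
    """Hypothetical mimic of a ModelOutput that only exposes fields that were set."""

    last_hidden_state: Optional[list] = None
    hidden_states: Optional[Tuple[list, ...]] = None

    def to_dict(self) -> "OrderedDict[str, object]":
        # Fields left as None are skipped, so they never show up as keys.
        return OrderedDict(
            (f.name, getattr(self, f.name)) for f in fields(self) if getattr(self, f.name) is not None
        )

    def keys(self):
        return self.to_dict().keys()


print(ToyModelOutput(last_hidden_state=[[0.0]]).keys())
```

Passing `hidden_states` as well makes it appear as a second key, in field-declaration order.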

BPNetPreTrainedModel

Bases: PreTrainedModel

An abstract class that handles weight initialization and provides a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
class BPNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BPNetConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["BPNetLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
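The scales chosen by `_init_weights` can be worked out by hand. The sketch below follows the fan computation used by `torch.nn.init` (dimension 0 is outputs, dimension 1 is inputs, remaining dimensions form the receptive field) and evaluates the convolution standard deviation and the linear weight and bias bounds. The weight shapes in the example, a `(64, 64, 3)` `Conv1d` kernel and a `(2, 64)` `Linear` weight, are illustrative assumptions rather than values read from the BPNet config.

```python
import math
from typing import Sequence, Tuple


def fan_in_and_fan_out(shape: Sequence[int]) -> Tuple[int, int]:
    """Same rule as torch.nn.init: dim 0 = outputs, dim 1 = inputs, the rest = receptive field."""
    receptive_field = math.prod(shape[2:])  # empty product is 1 for 2-D weights
    return shape[1] * receptive_field, shape[0] * receptive_field


def conv_kaiming_normal_std(shape: Sequence[int]) -> float:
    # kaiming_normal_(mode="fan_out", nonlinearity="relu"): std = sqrt(2) / sqrt(fan_out)
    _, fan_out = fan_in_and_fan_out(shape)
    return math.sqrt(2.0) / math.sqrt(fan_out)


def linear_uniform_bounds(shape: Sequence[int]) -> Tuple[float, float]:
    fan_in, _ = fan_in_and_fan_out(shape)
    if fan_in == 0:
        return 0.0, 0.0
    # kaiming_uniform_(a=sqrt(5)): gain = sqrt(2 / (1 + a**2)), bound = gain * sqrt(3 / fan_in)
    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1 + a**2))
    weight_bound = gain * math.sqrt(3.0 / fan_in)
    # Bias bound used in _init_weights: 1 / sqrt(fan_in)
    bias_bound = 1 / math.sqrt(fan_in)
    return weight_bound, bias_bound


print(conv_kaiming_normal_std((64, 64, 3)))  # sqrt(2 / 192)
print(linear_uniform_bounds((2, 64)))        # both bounds equal 1 / sqrt(64)
```

With `a = sqrt(5)` the gain is `sqrt(1/3)`, so the weight bound simplifies to `1 / sqrt(fan_in)`, the same bound applied to the bias. This reproduces the default `nn.Linear` reset while routing it through the checkpoint-aware `init` wrappers.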

BPNetProfilePredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of BPNetForProfilePrediction.

The standard single-`logits` predictor dataclasses cannot express BPNet’s factorized output, so this model-local dataclass exposes the two terminal branches separately. Use `BPNetForProfilePrediction.postprocess` to recombine them.
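The exact transform applied by `postprocess` is not shown in this section. In BPNet-style models the usual recombination takes a softmax of the profile logits over positions (the shape of the track) and scales it by the predicted total count. The plain-Python sketch below assumes that scheme and assumes `counts = exp(count_logits)`; the library may instead use another transform, such as `exp(x) - 1` for a count head trained on `log(1 + counts)`. `recombine` is a hypothetical helper operating on one sequence.

```python
import math
from typing import List


def recombine(profile_logits: List[List[float]], count_logits: List[float]) -> List[List[float]]:
    """Hypothetical recombination for a single sequence.

    profile_logits: [sequence_length][num_labels]; count_logits: [num_labels].
    Assumes total counts = exp(count_logit) for each task/strand label.
    """
    length = len(profile_logits)
    num_labels = len(count_logits)
    profile = [[0.0] * num_labels for _ in range(length)]
    for j in range(num_labels):
        column = [profile_logits[i][j] for i in range(length)]
        peak = max(column)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - peak) for x in column]
        total = sum(weights)
        scale = math.exp(count_logits[j])
        for i in range(length):
            profile[i][j] = weights[i] / total * scale
    return profile


# Flat logits over 4 positions with a predicted total of 8 give 2.0 per position.
print(recombine([[0.0], [0.0], [0.0], [0.0]], [math.log(8)]))
```

Under this assumption each label's track sums to `exp(count_logits[j])`, so the profile branch only controls where the signal falls and the count branch controls how much there is.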

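Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided | Composite multinomial-NLL (profile) + weighted count-MSE (count) loss. | `None` |
| `profile_logits` | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)` | Per-position multinomial logits, where `num_labels = num_tasks * num_strands`. | `None` |
| `count_logits` | `torch.FloatTensor` of shape `(batch_size, num_labels)` | Per task/strand log-count scalars. | `None` |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`. | `None` |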
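The composite loss is only named above. The plain-Python sketch below computes it for a single task/strand track under stated assumptions: the profile term is the full multinomial negative log-likelihood (including the combinatorial term), the count term is a squared error against `log(1 + total counts)` as is common in BPNet-style models, and `count_weight` is a hypothetical scalar standing in for the weighting. How the library reduces over batch, tasks, and strands is not shown here and may differ.

```python
import math
from typing import List


def multinomial_nll(counts: List[int], logits: List[float]) -> float:
    """-log Multinomial(counts | N, softmax(logits)) for one track."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))  # stable log-sum-exp
    log_probs = [x - log_norm for x in logits]
    n = sum(counts)
    # log of the multinomial coefficient N! / prod(c_i!)
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    return -(log_coef + sum(c * lp for c, lp in zip(counts, log_probs)))


def composite_loss(counts: List[int], profile_logits: List[float], count_logit: float, count_weight: float = 1.0) -> float:
    """Hypothetical profile NLL + weighted count MSE for one track."""
    profile_loss = multinomial_nll(counts, profile_logits)
    target = math.log1p(sum(counts))  # assumed count target: log(1 + total counts)
    count_loss = (count_logit - target) ** 2
    return profile_loss + count_weight * count_loss


# Two reads spread over 4 equally likely positions, with a perfect count prediction:
print(composite_loss([1, 1, 0, 0], [0.0, 0.0, 0.0, 0.0], math.log(3)))  # log(8)
```

Keeping the two terms separate means the profile term is invariant to the total read depth, while `count_weight` trades off shape accuracy against total-count accuracy.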
Source code in multimolecule/models/bpnet/modeling_bpnet.py
Python
@dataclass
class BPNetProfilePredictorOutput(ModelOutput):
    """
    Base class for outputs of [`BPNetForProfilePrediction`][multimolecule.models.BPNetForProfilePrediction].

    The standard single-`logits` predictor dataclasses cannot express BPNet's factorized output, so this model-local
    dataclass exposes the two terminal branches separately. Use
    [`postprocess`][multimolecule.models.BPNetForProfilePrediction.postprocess] to recombine them.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL (profile) + weighted count-MSE (count) loss.
        profile_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-position multinomial logits, where `num_labels = num_tasks * num_strands`.
        count_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Per task/strand log-count scalars.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`.
    """

    loss: torch.FloatTensor | None = None
    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None