ProCapNet

Base-resolution convolutional neural network for predicting PRO-cap transcription-initiation signal from DNA sequence.

Disclaimer

This is an UNOFFICIAL implementation of Dissecting the cis-regulatory syntax of transcription initiation with deep learning by Kelly Cochran et al.

The OFFICIAL repository of ProCapNet is at kundajelab/ProCapNet.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.

The team releasing ProCapNet did not write this model card; it was written by the MultiMolecule team.

Model Details

ProCapNet is a convolutional neural network (CNN) trained to predict base-resolution PRO-cap transcription-initiation signal from primary DNA sequence. Its architecture is largely adapted from Jacob Schreiber’s bpnet-lite and shares BPNet’s dilated-convolution backbone and profile/count factorization. The output is two-stranded (plus / minus strand) and factorized into a profile branch and a count branch, which ProCapNetForProfilePrediction.postprocess recombines into the final track. Training is mappability-aware. Please refer to the Training Details section for more information on the training process.

Model Specification

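| Input Length (bp) | Profile Length (bp) | Num Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) |
| ----------------- | ------------------- | ---------- | ----------- | ------------------ | --------- | -------- |
| 2114              | 1000                | 9          | 512         | 6.43               | 27.17     | 13.58    |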

FLOPs and MACs are measured on the canonical 2114 bp ProCapNet input window.

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Transcription-Initiation Profile Prediction

You can use this model directly to predict PRO-cap transcription-initiation profiles of a DNA sequence:

Python
>>> from multimolecule import DnaTokenizer, ProCapNetForProfilePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/procapnet")
>>> model = ProCapNetForProfilePrediction.from_pretrained("multimolecule/procapnet")
>>> output = model(**tokenizer(("ACGT" * 529)[:2114], return_tensors="pt"))

>>> output.keys()
odict_keys(['profile_logits', 'count_logits'])

>>> output["profile_logits"].shape
torch.Size([1, 1000, 2])

>>> output["count_logits"].shape
torch.Size([1, 1])

>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 1000, 2])

The recombined track is the usable base-resolution prediction. Its last dimension holds the num_strands = 2 PRO-cap signal predictions, ordered plus, minus.
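If you also run the model on the reverse complement of a window (a common BPNet-style test-time averaging step; an assumption here, not something this card prescribes), the resulting track must be mapped back to forward coordinates before averaging: reversing the sequence flips the position axis, and complementing it swaps the plus and minus strands. A minimal sketch on stand-in tensors, assuming the postprocessed layout (batch_size, profile_length, num_strands) with strands ordered (plus, minus) and a profile window centered in the input:

```python
import torch


def align_reverse_complement(rc_track: torch.Tensor) -> torch.Tensor:
    """Map a track predicted on the reverse-complement window back to forward coordinates.

    Assumes shape (batch_size, profile_length, num_strands), strands ordered (plus, minus),
    and a centered profile window, so reversing the input mirrors the profile.
    """
    # Reverse positions (dim 1) and swap plus/minus by reversing the strand axis (dim 2).
    return rc_track.flip(dims=(1, 2))


def average_with_reverse_complement(track: torch.Tensor, rc_track: torch.Tensor) -> torch.Tensor:
    """Average forward and re-aligned reverse-complement predictions."""
    return 0.5 * (track + align_reverse_complement(rc_track))


# Stand-in tensors in place of real model outputs.
forward_track = torch.rand(1, 1000, 2)
# A perfectly strand-symmetric model would predict exactly this on the reverse complement.
rc_track = forward_track.flip(dims=(1, 2))
averaged = average_with_reverse_complement(forward_track, rc_track)
print(torch.allclose(averaged, forward_track))  # True
```

The flip on the strand axis relies on there being exactly two strands in (plus, minus) order; a different layout would need an explicit index swap.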

Interface

  • Input length: 2114 bp DNA window
  • Profile length: 1000 bp, two-stranded (plus / minus)
  • Output: factorized (profile_logits, count_logits); recombine the base-resolution PRO-cap track via ProCapNetForProfilePrediction.postprocess
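Because the model maps a 2114 bp input to a centered 1000 bp profile, each prediction covers only the middle of its window, with (2114 - 1000) / 2 = 557 bp of flank on each side. One way to predict across a longer region is to tile the profile windows edge to edge and pad each input by that flank. A coordinate-only sketch (0-based, half-open intervals); the function name and the non-overlapping tiling are illustrative assumptions, and regions near chromosome ends would need separate handling:

```python
from typing import List, Tuple

SEQUENCE_LENGTH = 2114  # input window (bp)
PROFILE_LENGTH = 1000  # centered output profile (bp)
FLANK = (SEQUENCE_LENGTH - PROFILE_LENGTH) // 2  # 557 bp on each side


def tile_region(start: int, end: int) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
    """Return (input_window, profile_window) pairs whose profiles tile [start, end).

    The last profile window may extend past `end`; trim its prediction afterwards.
    """
    if end <= start:
        raise ValueError("end must be greater than start")
    windows = []
    for profile_start in range(start, end, PROFILE_LENGTH):
        profile_end = profile_start + PROFILE_LENGTH
        input_window = (profile_start - FLANK, profile_end + FLANK)
        windows.append((input_window, (profile_start, profile_end)))
    return windows


for input_window, profile_window in tile_region(10_000, 12_500):
    print(input_window, profile_window)
# (9443, 11557) (10000, 11000)
# (10443, 12557) (11000, 12000)
# (11443, 13557) (12000, 13000)
```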

Training Details

ProCapNet was trained to predict the base-resolution, two-stranded PRO-cap transcription-initiation signal in human cell lines. The default model is the K562 model.

Training Data

The published ProCapNet models were trained on PRO-cap signal using ~2 kb genomic windows. The default K562 model was trained on K562 PRO-cap experiment ENCSR261KBX. Training and test regions, observed signal tracks, and contribution scores are distributed through the same ENCODE release.

Training Procedure

Pre-training

The model was trained with a composite loss: a (strand-merged) multinomial negative log-likelihood on the per-position, two-stranded profile shape plus a mean-squared-error regression on log(count + 1) total counts.

  • Optimizer: Adam
  • Training is mappability-aware
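The composite objective can be sketched in PyTorch as below. This is an illustrative re-implementation under stated assumptions, not the MultiMolecule head's code: the multinomial is taken jointly over all positions and both strands (matching the single strand-merged count), the count target is log(total + 1) as described above, the count_loss_weight argument mirrors the config field of the same name, and mappability masking is omitted.

```python
import torch
import torch.nn.functional as F


def procapnet_style_loss(
    profile_logits: torch.Tensor,  # (batch, profile_length, num_strands)
    count_logits: torch.Tensor,  # (batch, 1), predicted log(total + 1)
    observed: torch.Tensor,  # (batch, profile_length, num_strands), non-negative read counts
    count_loss_weight: float = 1.0,
) -> torch.Tensor:
    batch_size = profile_logits.shape[0]
    # Joint multinomial over strands and positions: flatten both axes together.
    log_probs = F.log_softmax(profile_logits.reshape(batch_size, -1), dim=-1)
    counts = observed.reshape(batch_size, -1)
    total = counts.sum(dim=-1)
    # Multinomial log-likelihood: log N! - sum(log y_i!) + sum(y_i * log p_i).
    log_likelihood = (
        torch.lgamma(total + 1)
        - torch.lgamma(counts + 1).sum(dim=-1)
        + (counts * log_probs).sum(dim=-1)
    )
    profile_loss = -log_likelihood.mean()
    # Count branch: mean-squared error against log(total + 1).
    count_loss = F.mse_loss(count_logits.squeeze(-1), torch.log1p(total))
    return profile_loss + count_loss_weight * count_loss


torch.manual_seed(0)
observed = torch.poisson(torch.full((2, 1000, 2), 0.05))
loss = procapnet_style_loss(torch.randn(2, 1000, 2), torch.randn(2, 1), observed)
print(loss.item())  # a non-negative scalar
```

Both terms are non-negative (a negative log-probability and a squared error), so the weight only trades off how strongly the total magnitude is fit relative to the profile shape.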

Citation

BibTeX
@article{cochran2024procapnet,
  author    = {Cochran, Kelly and Yin, Melody and Mantripragada, Anika and Schreiber, Jacob and Marinov, Georgi K. and Shah, Sagar R. and Yu, Haiyuan and Lis, John T. and Kundaje, Anshul},
  title     = {Dissecting the cis-regulatory syntax of transcription initiation with deep learning},
  journal   = {bioRxiv},
  year      = 2024,
  doi       = {10.1101/2024.05.28.596138},
  note      = {Preprint}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the ProCapNet paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.procapnet

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): alphabet to use for tokenization.
    • If it is None, the standard DNA alphabet is used.
    • If it is a string, it should name a predefined alphabet. The options include standard, iupac, streamline, and nucleobase.
    • If it is an Alphabet or a list of characters, that specific alphabet is used.
  • nmers (int, default: 1): size of k-mer to tokenize.
  • codon (bool, default: False): whether to tokenize into codons.
  • replace_U_with_T (bool, default: True): whether to replace U with T.
  • do_upper_case (bool, default: True): whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
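The difference between nmers=3 and codon=True in _tokenize above is the step size: k-mers are read with a sliding window of stride 1 (overlapping), while codons use stride 3 (non-overlapping). A standalone sketch of just that splitting step, outside the tokenizer class (the helper names are illustrative, not part of MultiMolecule):

```python
from typing import List


def split_kmers(text: str, k: int) -> List[str]:
    """Overlapping k-mers with stride 1 (mirrors the nmers > 1 branch)."""
    return [text[i : i + k] for i in range(len(text) - k + 1)]


def split_codons(text: str) -> List[str]:
    """Non-overlapping triplets (mirrors the codon branch)."""
    if len(text) % 3 != 0:
        raise ValueError(f"length must be a multiple of 3, got {len(text)}")
    return [text[i : i + 3] for i in range(0, len(text), 3)]


sequence = "tataaagta".upper()
print(split_kmers(sequence, 3))  # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_codons(sequence))  # ['TAT', 'AAA', 'GTA']
```

The token counts match the examples above: seven 3-mer ids versus three codon ids between the special tokens.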

ProCapNetConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a ProCapNetModel. It is used to instantiate a ProCapNet model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the published ProCapNet K562 PRO-cap architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

ProCapNet predicts the base-resolution PRO-cap transcription-initiation signal. Its output is factorized into two terminal branches that share the dilated-convolution backbone:

  • a profile branch producing per-position, two-stranded multinomial logits of shape (batch_size, profile_length, num_strands);
  • a count branch producing a single strand-merged log-count scalar of shape (batch_size, 1).

Parameters:

  • vocab_size (int, default: 5): vocabulary size of the ProCapNet model. Defines the number of one-hot input channels derived from input_ids. Defaults to 5 to match the MultiMolecule streamline DNA alphabet (ACGTN).
  • sequence_length (int, default: 2114): the canonical input DNA sequence length in base pairs.
  • profile_length (int, default: 1000): the centered output profile length in base pairs.
  • hidden_size (int, default: 512): number of channels in the convolutional backbone.
  • stem_kernel_size (int, default: 21): kernel size of the first (motif) convolution.
  • num_dilated_layers (int, default: 8): number of dilated residual convolution blocks following the stem.
  • dilated_kernel_size (int, default: 3): kernel size of each dilated residual convolution.
  • profile_kernel_size (int, default: 75): kernel size of the profile-branch convolution.
  • num_strands (int, default: 2): number of strands predicted per position (plus / minus). ProCapNet is a two-stranded model.
  • hidden_act (str, default: 'relu'): the non-linear activation function in the backbone.
  • count_loss_weight (float, default: 1.0): the weight applied to the count regression loss when combining it with the profile multinomial loss.
  • head (HeadConfig | None, default: None): the configuration of the generic token prediction head. If not provided, it defaults to regression.
  • output_hidden_states (bool, default: False): whether to output the backbone hidden states.
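The default sizes fit together exactly under the dilation schedule used by BPNet-style models such as bpnet-lite, where dilated layer i (counting from 1) uses dilation 2**i. That schedule is an assumption here; the config does not expose it. Each convolution with kernel size k and dilation d widens the receptive field by (k - 1) * d, so the sketch below computes the receptive field of a single profile position:

```python
def receptive_field(
    stem_kernel_size: int = 21,
    num_dilated_layers: int = 8,
    dilated_kernel_size: int = 3,
    profile_kernel_size: int = 75,
) -> int:
    """Receptive field (bp) of one profile position, assuming dilation 2**i for dilated layer i."""
    field = 1
    field += stem_kernel_size - 1  # stem (motif) convolution, dilation 1
    for i in range(1, num_dilated_layers + 1):
        field += (dilated_kernel_size - 1) * 2**i  # dilated residual block i (assumed dilation 2**i)
    field += profile_kernel_size - 1  # profile-branch convolution, dilation 1
    return field


rf = receptive_field()
print(rf)  # 1115
print(2114 - 1000 == rf - 1)  # True
```

Under that assumption, the 557 bp flanks on either side of the 1000 bp profile are exactly what the outermost profile positions need: 20 + 1020 + 74 = 1114 = 2114 - 1000.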

Examples:

Python Console Session
>>> from multimolecule import ProCapNetConfig, ProCapNetModel
>>> # Initializing a ProCapNet style configuration
>>> configuration = ProCapNetConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = ProCapNetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/procapnet/configuration_procapnet.py
Python
class ProCapNetConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`ProCapNetModel`][multimolecule.models.ProCapNetModel]. It is used to instantiate a ProCapNet model according to
    the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the published ProCapNet
    [K562 PRO-cap](https://www.encodeproject.org/experiments/ENCSR261KBX/) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    ProCapNet predicts the base-resolution PRO-cap transcription-initiation signal. Its output is factorized into two
    terminal branches that share the dilated-convolution backbone:

    - a *profile* branch producing per-position, two-stranded multinomial logits of shape
      `(batch_size, profile_length, num_strands)`;
    - a *count* branch producing a single strand-merged log-count scalar of shape `(batch_size, 1)`.

    Args:
        vocab_size:
            Vocabulary size of the ProCapNet model. Defines the number of one-hot input channels derived from
            `input_ids`. Defaults to 5 to match the MultiMolecule `streamline` DNA alphabet (`ACGTN`).
        sequence_length:
            The canonical input DNA sequence length in base pairs.
            Defaults to 2114.
        profile_length:
            The centered output profile length in base pairs.
            Defaults to 1000.
        hidden_size:
            Number of channels in the convolutional backbone.
        stem_kernel_size:
            Kernel size of the first (motif) convolution.
        num_dilated_layers:
            Number of dilated residual convolution blocks following the stem.
        dilated_kernel_size:
            Kernel size of each dilated residual convolution.
        profile_kernel_size:
            Kernel size of the profile-branch convolution.
        num_strands:
            Number of strands predicted per position (plus / minus). ProCapNet is a two-stranded model.
        hidden_act:
            The non-linear activation function (function or string) in the backbone.
        count_loss_weight:
            The weight applied to the count regression loss when combining it with the profile multinomial loss.
        head:
            The configuration of the generic token prediction head. If not provided, it defaults to regression.
        output_hidden_states:
            Whether to output the backbone hidden states.

    Examples:
        >>> from multimolecule import ProCapNetConfig, ProCapNetModel
        >>> # Initializing a ProCapNet style configuration
        >>> configuration = ProCapNetConfig()
        >>> # Initializing a model (with random weights) from the configuration
        >>> model = ProCapNetModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "procapnet"

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 2114,
        profile_length: int = 1000,
        hidden_size: int = 512,
        stem_kernel_size: int = 21,
        num_dilated_layers: int = 8,
        dilated_kernel_size: int = 3,
        profile_kernel_size: int = 75,
        num_strands: int = 2,
        hidden_act: str = "relu",
        count_loss_weight: float = 1.0,
        head: HeadConfig | None = None,
        output_hidden_states: bool = False,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        if num_dilated_layers < 1:
            raise ValueError(f"num_dilated_layers ({num_dilated_layers}) must be at least 1.")
        if sequence_length < profile_length + profile_kernel_size - 1:
            raise ValueError(
                "sequence_length must be at least profile_length + profile_kernel_size - 1 "
                f"({profile_length + profile_kernel_size - 1}), but got {sequence_length}."
            )
        if profile_length < 1:
            raise ValueError(f"profile_length ({profile_length}) must be at least 1.")
        if num_strands < 1:
            raise ValueError(f"num_strands ({num_strands}) must be at least 1.")
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.profile_length = profile_length
        self.hidden_size = hidden_size
        self.stem_kernel_size = stem_kernel_size
        self.num_dilated_layers = num_dilated_layers
        self.dilated_kernel_size = dilated_kernel_size
        self.profile_kernel_size = profile_kernel_size
        self.num_strands = num_strands
        self.hidden_act = hidden_act
        self.count_loss_weight = count_loss_weight
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_hidden_states = output_hidden_states

    @property
    def num_labels(self) -> int:
        return self.num_strands

    @num_labels.setter
    def num_labels(self, value: int) -> None:
        # ``PretrainedConfig.__init__`` assigns ``num_labels``; ProCapNet derives it from
        # ``num_strands`` (the two-stranded profile), so the assignment is intentionally ignored.
        pass
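ProCapNetConfig turns num_labels into a read-only view of num_strands: the getter returns num_strands, and the setter silently discards writes, including the one the base-class __init__ performs. The standalone sketch below reproduces that pattern on plain classes to show its effect; BaseConfig and StrandConfig are hypothetical stand-ins, not MultiMolecule or Transformers classes.

```python
class BaseConfig:
    """Hypothetical stand-in for a base config whose __init__ assigns num_labels."""

    def __init__(self, num_labels: int = 2):
        self.num_labels = num_labels


class StrandConfig(BaseConfig):
    def __init__(self, num_strands: int = 2, **kwargs):
        super().__init__(**kwargs)  # the base-class write to num_labels is swallowed by the setter
        self.num_strands = num_strands

    @property
    def num_labels(self) -> int:
        return self.num_strands

    @num_labels.setter
    def num_labels(self, value: int) -> None:
        # Derived from num_strands; assignments are intentionally ignored.
        pass


config = StrandConfig(num_strands=2, num_labels=7)
config.num_labels = 5
print(config.num_labels)  # 2
```

Because a property is a data descriptor on the class, it takes precedence over the instance attribute the base class tries to create, so num_labels can only change through num_strands.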

ProCapNetForProfilePrediction

Bases: ProCapNetPreTrainedModel

ProCapNet with the factorized profile/count head for base-resolution PRO-cap signal prediction.

This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a per-position value for every position of the centered profile window. The single base-resolution PRO-cap transcription-initiation task is factorized into two terminal branches sharing the backbone:

  • profile_logits: per-position, two-stranded multinomial logits of shape (batch_size, profile_length, num_strands);
  • count_logits: a single strand-merged log-count scalar of shape (batch_size, 1).

Unlike single-stranded BPNet, the ProCapNet profile is a joint multinomial over both strands and all positions (the plus / minus strands share one count), so postprocess normalizes the profile over the strand-and-position axes jointly.
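To make the distinction concrete: a per-strand softmax (over positions only) yields two distributions that each sum to 1, so scaling both by one shared count would fix the plus/minus split at 50/50. The joint softmax used by postprocess normalizes over all profile_length × num_strands entries at once, so the logits also decide how the total is divided between strands. A small comparison on random stand-in logits (the per-strand variant is shown only for contrast):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
profile_logits = torch.randn(1, 1000, 2)  # (batch, profile_length, num_strands)
count_logits = torch.tensor([[3.0]])  # (batch, 1)

# Joint normalization over strands and positions (the form postprocess uses).
batch_size = profile_logits.shape[0]
joint = F.softmax(profile_logits.reshape(batch_size, -1), dim=-1).reshape(profile_logits.shape)

# Per-strand normalization over positions only (for contrast; not what ProCapNet uses).
per_strand = F.softmax(profile_logits, dim=1)

print(joint.sum(dim=(1, 2)))  # one distribution summing to 1
print(per_strand.sum(dim=1))  # two distributions, one per strand, each summing to 1

track = joint * torch.exp(count_logits).unsqueeze(1)
strand_totals = track.sum(dim=1)  # (batch, num_strands): the learned plus/minus split
print(strand_totals.sum(dim=-1), torch.exp(count_logits).squeeze(-1))  # the two totals agree
```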

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProCapNetConfig, ProCapNetForProfilePrediction, DnaTokenizer
>>> config = ProCapNetConfig()
>>> model = ProCapNetForProfilePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/procapnet")
>>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
>>> output = model(**input)
>>> output["profile_logits"].shape
torch.Size([1, 1000, 2])
>>> output["count_logits"].shape
torch.Size([1, 1])
>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 1000, 2])
Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
class ProCapNetForProfilePrediction(ProCapNetPreTrainedModel):
    """
    ProCapNet with the factorized profile/count head for base-resolution PRO-cap signal prediction.

    This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a
    per-position value for every position of the centered profile window. The single base-resolution PRO-cap
    transcription-initiation task is factorized into two terminal branches sharing the backbone:

    - `profile_logits`: per-position, two-stranded multinomial logits of shape
      `(batch_size, profile_length, num_strands)`;
    - `count_logits`: a single strand-merged log-count scalar of shape `(batch_size, 1)`.

    Unlike single-stranded BPNet, the ProCapNet profile is a joint multinomial over **both strands and all
    positions** (the plus / minus strands share one count), so
    [`postprocess`][multimolecule.models.ProCapNetForProfilePrediction.postprocess] normalizes the profile over the
    strand-and-position axes jointly.

    Examples:
        >>> import torch
        >>> from multimolecule import ProCapNetConfig, ProCapNetForProfilePrediction, DnaTokenizer
        >>> config = ProCapNetConfig()
        >>> model = ProCapNetForProfilePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/procapnet")
        >>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
        >>> output = model(**input)
        >>> output["profile_logits"].shape
        torch.Size([1, 1000, 2])
        >>> output["count_logits"].shape
        torch.Size([1, 1])
        >>> track = model.postprocess(output)
        >>> track.shape
        torch.Size([1, 1000, 2])
    """

    def __init__(self, config: ProCapNetConfig):
        super().__init__(config)
        self.model = ProCapNetModel(config)
        self.profile_count_head = ProCapNetProfileCountHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_strands == 2:
            return ["plus", "minus"]
        return [f"strand_{index}" for index in range(self.config.num_strands)]

    @merge_with_config_defaults
    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: dict[str, Tensor] | Tuple[Tensor, Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ProCapNetProfilePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        head_output = self.profile_count_head(outputs.last_hidden_state, labels)

        return ProCapNetProfilePredictorOutput(
            loss=head_output.loss,
            profile_logits=head_output.profile_logits,
            count_logits=head_output.count_logits,
            hidden_states=outputs.hidden_states,
        )

    def postprocess(self, outputs: ProCapNetProfilePredictorOutput | ModelOutput) -> Tensor:
        r"""
        Recombine the factorized profile and count branches into the usable base-resolution track.

        ProCapNet does not predict the signal track directly; the profile branch predicts the *shape* and the count
        branch predicts the *total magnitude* (in log space). Because ProCapNet is two-stranded with a single
        strand-merged count, the profile is a joint multinomial over **both strands and all positions**. The usable
        prediction recombines them as `softmax(profile_logits, strands & positions) * exp(count_logits)`.

        Args:
            outputs: The output of
                [`ProCapNetForProfilePrediction`][multimolecule.models.ProCapNetForProfilePrediction].

        Returns:
            The predicted base-resolution track of shape `(batch_size, profile_length, num_strands)`.
        """
        profile_logits = outputs["profile_logits"]
        count_logits = outputs["count_logits"]
        batch_size = profile_logits.shape[0]
        profile = F.softmax(profile_logits.reshape(batch_size, -1), dim=-1).reshape(profile_logits.shape)
        return profile * torch.exp(count_logits).unsqueeze(1)

postprocess

Python
postprocess(
    outputs: ProCapNetProfilePredictorOutput | ModelOutput,
) -> Tensor

Recombine the factorized profile and count branches into the usable base-resolution track.

ProCapNet does not predict the signal track directly; the profile branch predicts the shape and the count branch predicts the total magnitude (in log space). Because ProCapNet is two-stranded with a single strand-merged count, the profile is a joint multinomial over both strands and all positions. The usable prediction recombines them as softmax(profile_logits, strands & positions) * exp(count_logits).
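The same recombination can be written in log space, log track = log_softmax(profile_logits) + count_logits, with a single exponentiation at the end. This is mathematically identical to the formula above; it is shown only as an alternative form that can be convenient when working with log-scale tracks, not as what postprocess does internally. A quick equivalence check on stand-in tensors:

```python
import torch
import torch.nn.functional as F


def recombine(profile_logits: torch.Tensor, count_logits: torch.Tensor) -> torch.Tensor:
    """Joint softmax over strands and positions, scaled by exp(count_logits) (as in postprocess)."""
    batch_size = profile_logits.shape[0]
    profile = F.softmax(profile_logits.reshape(batch_size, -1), dim=-1).reshape(profile_logits.shape)
    return profile * torch.exp(count_logits).unsqueeze(1)


def recombine_log_space(profile_logits: torch.Tensor, count_logits: torch.Tensor) -> torch.Tensor:
    """The same track computed as exp(log_softmax + count_logits)."""
    batch_size = profile_logits.shape[0]
    log_profile = F.log_softmax(profile_logits.reshape(batch_size, -1), dim=-1).reshape(profile_logits.shape)
    return torch.exp(log_profile + count_logits.unsqueeze(1))


torch.manual_seed(0)
profile_logits = torch.randn(4, 1000, 2)
count_logits = torch.randn(4, 1) + 5.0
same = torch.allclose(recombine(profile_logits, count_logits), recombine_log_space(profile_logits, count_logits))
print(same)  # True
```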

Parameters:

  • outputs (ProCapNetProfilePredictorOutput | ModelOutput, required): the output of ProCapNetForProfilePrediction.

Returns:

  • Tensor: the predicted base-resolution track of shape (batch_size, profile_length, num_strands).

Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
def postprocess(self, outputs: ProCapNetProfilePredictorOutput | ModelOutput) -> Tensor:
    r"""
    Recombine the factorized profile and count branches into the usable base-resolution track.

    ProCapNet does not predict the signal track directly; the profile branch predicts the *shape* and the count
    branch predicts the *total magnitude* (in log space). Because ProCapNet is two-stranded with a single
    strand-merged count, the profile is a joint multinomial over **both strands and all positions**. The usable
    prediction recombines them as `softmax(profile_logits, strands & positions) * exp(count_logits)`.

    Args:
        outputs: The output of
            [`ProCapNetForProfilePrediction`][multimolecule.models.ProCapNetForProfilePrediction].

    Returns:
        The predicted base-resolution track of shape `(batch_size, profile_length, num_strands)`.
    """
    profile_logits = outputs["profile_logits"]
    count_logits = outputs["count_logits"]
    batch_size = profile_logits.shape[0]
    profile = F.softmax(profile_logits.reshape(batch_size, -1), dim=-1).reshape(profile_logits.shape)
    return profile * torch.exp(count_logits).unsqueeze(1)

ProCapNetForTokenPrediction

Bases: ProCapNetPreTrainedModel

ProCapNet backbone with a randomly initialized generic token-prediction head.

This class is intended for downstream fine-tuning from the ProCapNet backbone. It returns the standard TokenPredictorOutput with a single logits field, unlike ProCapNetForProfilePrediction, which exposes the published factorized profile_logits / count_logits task head.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ProCapNetConfig, ProCapNetForTokenPrediction
>>> config = ProCapNetConfig()
>>> model = ProCapNetForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 16))
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 16, 2])
Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
class ProCapNetForTokenPrediction(ProCapNetPreTrainedModel):
    """
    ProCapNet backbone with a randomly initialized generic token-prediction head.

    This class is intended for downstream fine-tuning from the ProCapNet backbone. It returns the standard
    [`TokenPredictorOutput`][multimolecule.models.TokenPredictorOutput] with a single `logits` field, unlike
    [`ProCapNetForProfilePrediction`][multimolecule.models.ProCapNetForProfilePrediction], which exposes the
    published factorized `profile_logits` / `count_logits` task head.

    Examples:
        >>> import torch
        >>> from multimolecule import ProCapNetConfig, ProCapNetForTokenPrediction
        >>> config = ProCapNetConfig()
        >>> model = ProCapNetForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 16))
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 16, 2])
    """

    def __init__(self, config: ProCapNetConfig):
        super().__init__(config)
        self.model = ProCapNetModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | TokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)

        outputs = self.model(
            input_ids,
            attention_mask=head_attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        return TokenPredictorOutput(
            loss=output.loss,
            logits=output.logits,
            hidden_states=outputs.hidden_states,
        )

ProCapNetHeadOutput dataclass

Bases: ModelOutput

Output of the factorized ProCapNet profile/count head.

Parameters:

  • profile_logits (torch.FloatTensor of shape (batch_size, profile_length, num_strands), default: None): per-position, two-stranded multinomial logits.
  • count_logits (torch.FloatTensor of shape (batch_size, 1), default: None): strand-merged log-count scalar.
  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided; default: None): composite multinomial-NLL + weighted count-MSE loss.
Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
@dataclass
class ProCapNetHeadOutput(ModelOutput):
    """
    Output of the factorized ProCapNet profile/count head.

    Args:
        profile_logits (`torch.FloatTensor` of shape `(batch_size, profile_length, num_strands)`):
            Per-position, two-stranded multinomial logits.
        count_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Strand-merged log-count scalar.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL + weighted count-MSE loss.
    """

    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    loss: torch.FloatTensor | None = None

ProCapNetModel

Bases: ProCapNetPreTrainedModel

The bare ProCapNet dilated-convolution backbone producing per-position features.

Examples:

Python Console Session
>>> from multimolecule import ProCapNetConfig, ProCapNetModel, DnaTokenizer
>>> config = ProCapNetConfig()
>>> model = ProCapNetModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/procapnet")
>>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
>>> output = model(**input)
>>> output["last_hidden_state"].shape
torch.Size([1, 2114, 512])
Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
class ProCapNetModel(ProCapNetPreTrainedModel):
    """
    The bare ProCapNet dilated-convolution backbone producing per-position features.

    Examples:
        >>> from multimolecule import ProCapNetConfig, ProCapNetModel, DnaTokenizer
        >>> config = ProCapNetConfig()
        >>> model = ProCapNetModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/procapnet")
        >>> input = tokenizer(("ACGT" * 529)[:2114], return_tensors="pt")
        >>> output = model(**input)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 2114, 512])
    """

    def __init__(self, config: ProCapNetConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embeddings = ProCapNetEmbedding(config)
        self.encoder = ProCapNetEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ProCapNetModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_output = self.encoder(embedding_output, **kwargs)
        last_hidden_state = encoder_output.last_hidden_state.transpose(1, 2)

        return ProCapNetModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_output.hidden_states,
        )
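`nn.Conv1d` operates on channels-first tensors of shape `(batch_size, channels, length)`, which is why `forward` transposes the encoder output before returning it: after `transpose(1, 2)`, `last_hidden_state` has the documented `(batch_size, sequence_length, hidden_size)` layout. The snippet below illustrates the layout with a single stand-in convolution; the kernel size of 21 and the all-"A" input are illustrative assumptions, and the real `ProCapNetEmbedding` and `ProCapNetEncoder` internals are not shown on this page.

```python
import torch
from torch import nn

batch_size, length, hidden_size = 1, 2114, 512
# One-hot DNA in channels-first layout: (batch, 4 nucleotides, length).
one_hot = torch.zeros(batch_size, 4, length)
one_hot[:, 0, :] = 1.0  # an all-"A" sequence, purely for illustration

# Stand-in stem convolution; padding=10 with kernel_size=21 keeps the length unchanged.
conv = nn.Conv1d(in_channels=4, out_channels=hidden_size, kernel_size=21, padding=10)
features = conv(one_hot)                      # (batch, hidden, length)
last_hidden_state = features.transpose(1, 2)  # (batch, length, hidden)
print(features.shape, last_hidden_state.shape)
```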

ProCapNetModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the ProCapNet backbone.

Parameters:

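| Name | Type | Description | Default |
| --- | --- | --- | --- |
| last_hidden_state | `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)` | Per-position backbone features. | None |
| hidden_states | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of `torch.FloatTensor` (one for the stem output plus one per dilated layer) of shape `(batch_size, sequence_length, hidden_size)`. | None |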
Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
@dataclass
class ProCapNetModelOutput(ModelOutput):
    """
    Base class for outputs of the ProCapNet backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Per-position backbone features.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the stem output plus one per dilated layer) of shape `(batch_size,
            sequence_length, hidden_size)`.
    """

    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
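The `hidden_states` tuple holds the stem output plus one tensor per dilated layer. `ToyDilatedStack` below is a hypothetical module that sketches this bookkeeping for a BPNet-style residual stack (dilated convolution, ReLU, residual add); it is not `ProCapNetEncoder`, and its hidden size, layer count of 8, and kernel sizes are illustrative assumptions.

```python
import torch
from torch import nn


class ToyDilatedStack(nn.Module):
    """Hypothetical BPNet-style backbone that records one hidden state per stage."""

    def __init__(self, hidden_size: int = 16, num_dilated_layers: int = 8):
        super().__init__()
        self.stem = nn.Conv1d(4, hidden_size, kernel_size=21, padding=10)
        # For kernel_size=3, padding equal to the dilation keeps the length fixed.
        self.layers = nn.ModuleList(
            [
                nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=2**i, padding=2**i)
                for i in range(1, num_dilated_layers + 1)
            ]
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        h = torch.relu(self.stem(x))
        hidden_states = [h.transpose(1, 2)]  # stem output as (batch, length, hidden)
        for conv in self.layers:
            h = h + torch.relu(conv(h))  # residual connection around each dilated layer
            hidden_states.append(h.transpose(1, 2))
        return h.transpose(1, 2), tuple(hidden_states)


model = ToyDilatedStack()
x = torch.randn(2, 4, 2114)
last_hidden_state, hidden_states = model(x)
print(len(hidden_states), last_hidden_state.shape)  # 9 torch.Size([2, 2114, 16])
```

With a stem and 8 dilated layers the tuple has 9 entries, and its last entry equals `last_hidden_state`.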

ProCapNetPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weight initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
class ProCapNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ProCapNetConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["ProCapNetLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
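The `nn.Linear` branch mirrors PyTorch's own default for linear layers: `kaiming_uniform_` with `a=sqrt(5)` bounds the weights by 1/sqrt(fan_in), and the bias is drawn from the same range. Convolutions instead use He-normal initialization with `mode="fan_out"`, whose standard deviation is sqrt(2 / (out_channels * kernel_size)). The sketch below applies both rules with plain `torch.nn.init` (not the `transformers` `init` wrappers, which additionally skip tensors already loaded from a checkpoint) and checks the resulting bounds and spread; the layer sizes are arbitrary.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

# Linear layer: same rules as the nn.Linear branch above.
linear = nn.Linear(256, 64)
nn.init.kaiming_uniform_(linear.weight, a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(linear.weight)  # private helper, also used above
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(linear.bias, -bound, bound)
# With a=sqrt(5), the weight bound also works out to 1 / sqrt(fan_in).
print(fan_in, bound, linear.weight.abs().max().item())

# Conv1d layer: He-normal with fan_out, zero bias, as in the conv branch above.
conv = nn.Conv1d(64, 128, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
nn.init.zeros_(conv.bias)
fan_out = conv.out_channels * conv.kernel_size[0]  # 128 * 3 = 384
expected_std = math.sqrt(2.0 / fan_out)
print(expected_std, conv.weight.std().item())
```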

ProCapNetProfilePredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of ProCapNetForProfilePrediction.

The standard single-`logits` predictor dataclasses cannot express ProCapNet’s factorized output, so this model-local dataclass exposes the two terminal branches separately. Use `ProCapNetForProfilePrediction.postprocess` to recombine them.

Parameters:

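| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss | `torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided | Composite multinomial-NLL (profile) + weighted count-MSE (count) loss. | None |
| profile_logits | `torch.FloatTensor` of shape `(batch_size, profile_length, num_strands)` | Per-position, two-stranded multinomial logits. | None |
| count_logits | `torch.FloatTensor` of shape `(batch_size, 1)` | Strand-merged log-count scalar. | None |
| hidden_states | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`. | None |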
Source code in multimolecule/models/procapnet/modeling_procapnet.py
Python
@dataclass
class ProCapNetProfilePredictorOutput(ModelOutput):
    """
    Base class for outputs of
    [`ProCapNetForProfilePrediction`][multimolecule.models.ProCapNetForProfilePrediction].

    The standard single-`logits` predictor dataclasses cannot express ProCapNet's factorized output, so this
    model-local dataclass exposes the two terminal branches separately. Use
    [`postprocess`][multimolecule.models.ProCapNetForProfilePrediction.postprocess] to recombine them.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL (profile) + weighted count-MSE (count) loss.
        profile_logits (`torch.FloatTensor` of shape `(batch_size, profile_length, num_strands)`):
            Per-position, two-stranded multinomial logits.
        count_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Strand-merged log-count scalar.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`.
    """

    loss: torch.FloatTensor | None = None
    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
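`loss` is documented as a multinomial negative log-likelihood on the profile plus a weighted mean-squared error on the log count. `procap_style_loss` below is a hypothetical sketch of such a composite under BPNet-style assumptions: the multinomial is taken jointly over all positions and both strands, the count target is log(1 + total observed count), and `count_weight` is an illustrative name for the weighting factor, whose actual value and config field are not given on this page.

```python
import math

import torch
import torch.nn.functional as F


def procap_style_loss(
    profile_logits: torch.Tensor,  # (batch, profile_length, num_strands)
    count_logits: torch.Tensor,    # (batch, 1)
    observed: torch.Tensor,        # (batch, profile_length, num_strands), non-negative counts
    count_weight: float = 1.0,     # hypothetical weighting factor
) -> torch.Tensor:
    batch_size = profile_logits.shape[0]
    # Joint log-softmax over every (position, strand) cell.
    log_probs = F.log_softmax(profile_logits.reshape(batch_size, -1), dim=-1)
    counts = observed.reshape(batch_size, -1)
    total = counts.sum(dim=-1)
    # Multinomial log-likelihood: log(n!) - sum log(k_i!) + sum k_i * log p_i.
    log_likelihood = (
        torch.lgamma(total + 1)
        - torch.lgamma(counts + 1).sum(dim=-1)
        + (counts * log_probs).sum(dim=-1)
    )
    profile_loss = -log_likelihood.mean()
    # MSE between the predicted log count and log(1 + observed total).
    count_loss = F.mse_loss(count_logits.squeeze(-1), torch.log1p(total))
    return profile_loss + count_weight * count_loss


profile_logits = torch.zeros(1, 1000, 2)           # uniform profile: p = 1/2000 per cell
observed = torch.zeros(1, 1000, 2)
observed[0, 500, 0] = 3.0                          # three reads at one position on one strand
count_logits = torch.log1p(torch.tensor([[3.0]]))  # exact count prediction, so count loss is 0
loss = procap_style_loss(profile_logits, count_logits, observed)
print(loss.item(), 3 * math.log(2000))             # the two values agree
```

In this toy case the profile term reduces to 3 * log(2000), because each of the three reads lands in a cell with probability 1/2000 and the factorial terms cancel.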