
Pangolin

Convolutional neural network for predicting tissue-specific splice site strength from pre-mRNA sequences.

Disclaimer

This is an UNOFFICIAL implementation of Predicting RNA splicing from DNA sequence using Pangolin by Tony Zeng and Yang I. Li.

The OFFICIAL repository of Pangolin is at tkzeng/Pangolin.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.

The team that released Pangolin did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

Pangolin is a deep convolutional neural network (CNN) that predicts splice site strength from primary pre-mRNA sequence. It extends the dilated-residual SpliceAI architecture to predict tissue-specific splice site usage, and is trained on splicing measurements derived from RNA-seq data across multiple tissues. The network processes a one-hot encoded nucleotide sequence and, for each position, predicts a splice-site score and a splice-site usage score per tissue. Pangolin is typically used to estimate the effect of genetic variants on splicing by scoring reference and alternate sequences and taking the difference. Please refer to the Training Details section for more information on the training process.
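The reference-versus-alternate comparison described above can be sketched in plain Python. This is a minimal illustration, not the official Pangolin scoring procedure. It assumes you already have aligned per-position scores for one tissue channel from the reference and alternate sequences, so it only covers substitutions and not indels. It summarizes the difference as the largest increase and the largest decrease within a window around the variant. The function name `score_variant` and the `window` default are hypothetical, and the official tool's exact aggregation may differ.

```python
from typing import Sequence, Tuple


def score_variant(
    ref_scores: Sequence[float],
    alt_scores: Sequence[float],
    variant_pos: int,
    window: int = 50,
) -> Tuple[float, float]:
    """Hypothetical helper: largest increase and decrease of alt - ref near a variant.

    Assumes both inputs are per-position scores for the same tissue channel and
    are aligned one-to-one (a single-nucleotide substitution, no indels).
    """
    if len(ref_scores) != len(alt_scores):
        raise ValueError("reference and alternate scores must be aligned")
    if not 0 <= variant_pos < len(ref_scores):
        raise IndexError("variant position is outside the sequence")
    # Clip the +/- window around the variant to the sequence bounds.
    start = max(0, variant_pos - window)
    end = min(len(ref_scores), variant_pos + window + 1)
    deltas = [alt_scores[i] - ref_scores[i] for i in range(start, end)]
    # Positive deltas: the alternate allele strengthens splicing at that position.
    # Negative deltas: it weakens splicing.
    return max(max(deltas), 0.0), min(min(deltas), 0.0)


ref = [0.0, 0.1, 0.9, 0.1, 0.0]  # made-up reference scores
alt = [0.0, 0.6, 0.2, 0.1, 0.0]  # made-up alternate scores
increase, decrease = score_variant(ref, alt, variant_pos=2, window=2)
print(increase, decrease)  # roughly 0.5 and -0.7
```

In this toy example the variant both creates a stronger site one position upstream and weakens the existing site at the variant position. Reporting both extremes keeps either effect from hiding the other.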

The official release distributes tissue-specific ensembles. The canonical v2 model uses three replicate networks for each of the four tissue groups (heart, liver, brain, and testis), and the outputs of the replicates within each group are averaged. Ensemble membership is an implementation detail and is not exposed in the downstream API.
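Here is a rough sketch of how the per-tissue ensembles combine. It mirrors the averaging loop in `PangolinModel.forward` shown further down. Each replicate network produces a full set of num_tissues * 3 channels. The replicates of one tissue group are averaged, and only that group's own three-channel slice is kept. Plain Python lists stand in for tensors here, and the helper names `average` and `combine_ensembles` are hypothetical.

```python
from typing import List

Vector = List[float]


def average(vectors: List[Vector]) -> Vector:
    """Element-wise mean of equally sized vectors."""
    return [sum(values) / len(vectors) for values in zip(*vectors)]


def combine_ensembles(member_outputs: List[List[Vector]]) -> Vector:
    """Hypothetical helper combining one position's outputs from a tissue-grouped ensemble.

    member_outputs[t][r] is the output vector of replicate r in tissue group t.
    Every vector has num_tissues * 3 channels; after its replicates are averaged,
    group t contributes only channels [3 * t, 3 * t + 3).
    """
    combined: Vector = []
    for tissue, replicates in enumerate(member_outputs):
        mean = average(replicates)
        start = tissue * 3
        combined.extend(mean[start : start + 3])
    return combined


members = [
    [[1, 1, 1, 9, 9, 9], [3, 3, 3, 9, 9, 9]],  # tissue 0: two replicates
    [[9, 9, 9, 0, 2, 4], [9, 9, 9, 2, 4, 6]],  # tissue 1: two replicates
]
print(combine_ensembles(members))  # [2.0, 2.0, 2.0, 1.0, 3.0, 5.0]
```

The 9s show that channels belonging to other tissues are discarded. Each group is only trusted for its own tissue's outputs.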

Model Specification

| Num Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) |
| --- | --- | --- | --- | --- |
| 16 | 32 | 8.36 | 168.85 | 84.04 |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

RNA Splicing Site Prediction

You can use this model directly to predict per-nucleotide tissue-specific splice-site score and usage channels for a pre-mRNA sequence:

Python
>>> from multimolecule import DnaTokenizer, PangolinModel

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/pangolin")
>>> model = PangolinModel.from_pretrained("multimolecule/pangolin")
>>> output = model(tokenizer("AGCAGTCATTATGGCGAA", return_tensors="pt")["input_ids"])

>>> output.keys()
odict_keys(['last_hidden_state', 'logits'])

The logits tensor reproduces the original Pangolin output: for each of the four tissues, two splice-site score channels and one splice-site usage channel, for 12 channels in total.
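Because the channels are grouped tissue by tissue, each channel's index follows from the tissue's position in the tissue order. The sketch below builds channel names the same way the `output_channels` property of `PangolinModel` does in the source listing further down. It then splits one position's 12-value row of logits, which would come from something like `output.logits[0, i].tolist()`, into per-tissue groups. The sketch uses plain lists so it runs without PyTorch, and `split_by_tissue` is a hypothetical helper.

```python
from typing import Dict, List

TISSUES = ["heart", "liver", "brain", "testis"]  # default tissue order
KINDS = ["no_splice", "splice_site", "usage"]  # channel order within each tissue


def channel_names(tissues: List[str]) -> List[str]:
    """Names in logits order: heart_no_splice, heart_splice_site, heart_usage, liver_..."""
    return [f"{tissue}_{kind}" for tissue in tissues for kind in KINDS]


def split_by_tissue(row: List[float], tissues: List[str]) -> Dict[str, Dict[str, float]]:
    """Hypothetical helper: split one position's logits row into {tissue: {kind: value}}."""
    if len(row) != 3 * len(tissues):
        raise ValueError(f"expected {3 * len(tissues)} channels, got {len(row)}")
    # Tissue t owns channels [3 * t, 3 * t + 3).
    return {tissue: dict(zip(KINDS, row[3 * t : 3 * t + 3])) for t, tissue in enumerate(tissues)}


row = [0.9, 0.1, 0.05, 0.8, 0.2, 0.1, 0.7, 0.3, 0.2, 0.6, 0.4, 0.3]  # made-up values
print(split_by_tissue(row, TISSUES)["brain"])  # {'no_splice': 0.7, 'splice_site': 0.3, 'usage': 0.2}
```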

Downstream Use

Token Prediction

You can fine-tune Pangolin for per-nucleotide splice site strength regression with PangolinForTokenPrediction, which adds a shared token prediction head on top of the backbone.

Interface

  • Input length: variable-length pre-mRNA sequence
  • Padding: flanking context padded with N near transcript ends
  • Output: per-position tissue-specific channels (for each of the 4 tissues, 2 splice-site score channels and 1 splice-site usage channel, 12 channels in total)
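N padding in sequence space and the zero padding described for the `context` parameter of `PangolinConfig` amount to the same thing after one-hot encoding. This assumes, as in SpliceAI-style encodings, that N maps to an all-zero vector. The following minimal sketch encodes a sequence that way and pads context // 2 positions on each side. `one_hot` and `pad_context` are hypothetical names, not MultiMolecule functions.

```python
from typing import List

BASES = "ACGT"


def one_hot(sequence: str) -> List[List[int]]:
    """One-hot encode A/C/G/T; any other symbol (e.g. N) becomes an all-zero row (assumption)."""
    return [[1 if base == b else 0 for b in BASES] for base in sequence.upper()]


def pad_context(encoded: List[List[int]], context: int) -> List[List[int]]:
    """Add context // 2 all-zero rows (equivalent to N) on each side."""
    if context <= 0 or context % 2:
        raise ValueError("context must be a positive even integer")
    half = context // 2
    left = [[0] * len(BASES) for _ in range(half)]
    right = [[0] * len(BASES) for _ in range(half)]
    return left + encoded + right


padded = pad_context(one_hot("ACGTN"), context=4)
print(len(padded))  # 9: 2 padding rows + 5 bases + 2 padding rows
```

Trimming context // 2 positions from each side, as the encoder does before the prediction head, restores one output position per input nucleotide.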

Training Details

Pangolin was trained to predict tissue-specific splice site usage from primary pre-mRNA sequence.

Training Data

Pangolin was trained on splice site usage derived from RNA-seq data in heart, liver, brain, and testis tissues from human and three other species, using gene annotations from GENCODE. For each nucleotide whose splicing status was predicted, a sequence window centered on that nucleotide was used, with the flanking context padded with N (unknown nucleotide) when near transcript ends.
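One way to read the per-nucleotide windowing described above: to score position i, take the context // 2 nucleotides on either side of it and fill any positions that fall outside the transcript with N. The sketch below implements that reading. `centered_window` is a hypothetical helper, and the original training pipeline may have batched or chunked windows differently.

```python
def centered_window(sequence: str, position: int, context: int) -> str:
    """Hypothetical helper: context // 2 bases on each side of `position`, N-padded.

    The result has length 2 * (context // 2) + 1, i.e. context + 1 for even context.
    """
    if not 0 <= position < len(sequence):
        raise IndexError("position is outside the sequence")
    half = context // 2
    return "".join(
        sequence[i] if 0 <= i < len(sequence) else "N"
        for i in range(position - half, position + half + 1)
    )


print(centered_window("ACGT", 0, context=4))  # NNACG
```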

Training Procedure

Pre-training

The model was trained to minimize a combination of cross-entropy loss over splice-site classification and a regression loss over splice-site usage, comparing predictions against measurements derived from RNA-seq.

  • Optimizer: AdamW
  • Learning rate scheduler: Step decay
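To illustrate the combined objective, the sketch below computes a loss for a single position. It adds a two-class cross-entropy term on the splice-site score channels to a squared-error term on the usage channel. Two choices here are assumptions for the example, since the description above specifies neither: mean squared error as the regression loss, and equal weighting of the two terms. `combined_loss` is a hypothetical name.

```python
import math
from typing import Sequence


def combined_loss(
    site_probs: Sequence[float],
    site_label: int,
    usage_pred: float,
    usage_target: float,
    eps: float = 1e-12,
) -> float:
    """Hypothetical per-position loss: cross-entropy on splice-site probabilities plus squared error on usage.

    site_probs: softmax outputs (no_splice, splice_site) for one position.
    site_label: 0 for no splice site, 1 for a splice site.
    """
    # Clamp to eps so a zero probability does not produce log(0).
    cross_entropy = -math.log(max(site_probs[site_label], eps))
    regression = (usage_pred - usage_target) ** 2
    return cross_entropy + regression  # equal weighting is an assumption


print(combined_loss([0.5, 0.5], 1, 0.5, 0.5))  # log(2): uncertain site call, perfect usage
```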

Citation

BibTeX
@article{zeng2022predicting,
  author    = {Zeng, Tony and Li, Yang I.},
  title     = {Predicting RNA splicing from DNA sequence using Pangolin},
  journal   = {Genome Biology},
  volume    = {23},
  number    = {1},
  pages     = {103},
  year      = {2022},
  doi       = {10.1186/s13059-022-02664-4},
  publisher = {BioMed Central}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the MultiMolecule GitHub issues for any questions or comments on the model card.

Please contact the authors of the Pangolin paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.pangolin

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default None): alphabet to use for tokenization.
    • If it is None, the standard DNA alphabet will be used.
    • If it is a string, it should correspond to the name of a predefined alphabet. The options are:
      • standard
      • iupac
      • streamline
      • nucleobase
    • If it is an alphabet or a list of characters, that specific alphabet will be used.
  • nmers (int, default 1): Size of k-mer to tokenize.
  • codon (bool, default False): Whether to tokenize into codons.
  • replace_U_with_T (bool, default True): Whether to replace U with T.
  • do_upper_case (bool, default True): Whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If it is `None`, the standard DNA alphabet will be used.
            - If it is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If it is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
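To make the difference between the `nmers` and `codon` options concrete, here is a stand-alone restatement of the splitting logic in `_tokenize` above. It applies the default upper-casing and U-to-T replacement, then splits into overlapping k-mers or non-overlapping codons. It returns token strings rather than ids, because ids depend on the alphabet. `split_tokens` is a hypothetical function name.

```python
from typing import List


def split_tokens(text: str, nmers: int = 1, codon: bool = False) -> List[str]:
    """Hypothetical restatement of DnaTokenizer._tokenize with its default case and U->T handling."""
    text = text.upper().replace("U", "T")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length must be a multiple of 3 for codons, got {len(text)}")
        # Non-overlapping triplets: stride 3.
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # Overlapping k-mers: stride 1, giving len(text) - nmers + 1 tokens.
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(split_tokens("tataaagta", nmers=3))  # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_tokens("tataaagta", codon=True))  # ['TAT', 'AAA', 'GTA']
```

The 7 overlapping 3-mers and the 3 codons line up with the 7 and 3 ids between the special tokens in the `nmers=3` and `codon=True` examples above. For instance, TAT maps to 84 and GTA to 71 in both.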

PangolinConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a PangolinModel. It is used to instantiate a Pangolin model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Pangolin tkzeng/Pangolin architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default 5): Vocabulary size of the Pangolin model. Defines the number of different tokens that can be represented by the input_ids passed when calling PangolinModel. Defaults to 5 (A, C, G, T, N).
  • context (int, default 10000): The length of the context window. The input sequence is padded with zeros of length context // 2 on each side, and the encoder trims the same amount before the prediction head. Must be a positive even integer.
  • hidden_size (int, default 32): Dimensionality of the encoder layers.
  • stages (list[PangolinStageConfig] | None, default None): Configuration for each stage in the Pangolin model. Each stage is a PangolinStageConfig object. When None, four stages of 4 blocks each are used, with (kernel_size, dilation) of (11, 1), (11, 4), (21, 10), and (41, 25).
  • hidden_act (str, default 'relu'): The non-linear activation function (function or string) in the encoder. If a string, "gelu", "relu", "silu", and "gelu_new" are supported.
  • batch_norm_eps (float, default 1e-05): The epsilon used by the batch normalization layers.
  • batch_norm_momentum (float, default 0.1): The momentum used by the batch normalization layers.
  • num_ensemble (int, default 3): Number of replicate networks averaged inside each tissue-specific model group. The official Pangolin v2 release uses three replicates per tissue.
  • num_tissues (int, default 4): Number of tissue-specific model groups. The official release predicts four tissues (heart, liver, brain, testis), each with a splice-site score (2 channels) and a splice-site usage score (1 channel), for a total of num_tissues * 3 upstream output channels.
  • tissue_names (list[str] | None, default None): Names for the tissue-specific output groups. Defaults to the official Pangolin v2 tissue order: heart, liver, brain, and testis.
  • num_labels (int, default 4): Number of output labels for the TokenPredictionHead: one per-base splice-site usage value for each tissue.
  • head (HeadConfig | None, default None): Configuration for the TokenPredictionHead.
  • problem_type (str | None, default 'regression'): Problem type for the token prediction head.
  • output_contexts (bool, default False): Whether to output the context vectors for each stage.

Examples:

Python Console Session
>>> from multimolecule import PangolinConfig, PangolinModel
>>> # Initializing a Pangolin multimolecule/pangolin style configuration
>>> configuration = PangolinConfig()
>>> # Initializing a model (with random weights) from the multimolecule/pangolin style configuration
>>> model = PangolinModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/pangolin/configuration_pangolin.py
Python
class PangolinConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`PangolinModel`][multimolecule.models.PangolinModel]. It is used to instantiate a Pangolin model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the Pangolin [tkzeng/Pangolin](https://github.com/tkzeng/Pangolin) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the Pangolin model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`PangolinModel`].
            Defaults to 5 (`A`, `C`, `G`, `T`, `N`).
        context:
            The length of the context window. The input sequence is padded with zeros of length `context // 2` on each
            side, and the encoder trims the same amount before the prediction head.
        hidden_size:
            Dimensionality of the encoder layers.
        stages:
            Configuration for each stage in the Pangolin model. Each stage is a [`PangolinStageConfig`] object.
        hidden_act:
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"silu"` and `"gelu_new"` are supported.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers.
        num_ensemble:
            Number of replicate networks averaged inside each tissue-specific model group. The official Pangolin v2
            release uses three replicates per tissue.
        num_tissues:
            Number of tissue-specific model groups. The official release predicts four tissues (heart, liver, brain,
            testis), each with a splice-site score (2 channels) and a splice-site usage score (1 channel), for a total
            of `num_tissues * 3` upstream output channels.
        tissue_names:
            Names for the tissue-specific output groups. Defaults to the official Pangolin v2 tissue order: heart,
            liver, brain, and testis.
        num_labels:
            Number of output labels for the [`TokenPredictionHead`]. Defaults to 4, one per-base splice-site usage
            value per tissue.
        head:
            Configuration for the [`TokenPredictionHead`].
        problem_type:
            Problem type for the token prediction head.
        output_contexts:
            Whether to output the context vectors for each stage.

    Examples:
        >>> from multimolecule import PangolinConfig, PangolinModel
        >>> # Initializing a Pangolin multimolecule/pangolin style configuration
        >>> configuration = PangolinConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/pangolin style configuration
        >>> model = PangolinModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "pangolin"

    # Pangolin consumes raw nucleotide sequences (`A`, `C`, `G`, `T`, `N`) with no special tokens; `N` doubles
    # as the padding token. There is no beginning/end-of-sequence token to strip in the prediction head.
    pad_token_id: int = 4
    bos_token_id: int | None = None  # type: ignore[assignment]
    eos_token_id: int | None = None  # type: ignore[assignment]
    unk_token_id: int = 4
    mask_token_id: int | None = None  # type: ignore[assignment]
    null_token_id: int | None = None  # type: ignore[assignment]

    def __init__(
        self,
        vocab_size: int = 5,
        context: int = 10000,
        hidden_size: int = 32,
        stages: list[PangolinStageConfig] | None = None,
        hidden_act: str = "relu",
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        num_ensemble: int = 3,
        num_tissues: int = 4,
        tissue_names: list[str] | None = None,
        num_labels: int = 4,
        head: HeadConfig | None = None,
        problem_type: str | None = "regression",
        output_contexts: bool = False,
        pad_token_id: int = 4,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        unk_token_id: int = 4,
        mask_token_id: int | None = None,
        null_token_id: int | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, pad_token_id=pad_token_id, unk_token_id=unk_token_id, **kwargs)
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        self.mask_token_id = mask_token_id  # type: ignore[assignment]
        self.null_token_id = null_token_id  # type: ignore[assignment]
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.context = context
        if stages is None:
            stages = [
                PangolinStageConfig(num_blocks=4, kernel_size=11, dilation=1),
                PangolinStageConfig(num_blocks=4, kernel_size=11, dilation=4),
                PangolinStageConfig(num_blocks=4, kernel_size=21, dilation=10),
                PangolinStageConfig(num_blocks=4, kernel_size=41, dilation=25),
            ]
        self.stages = [
            stage if isinstance(stage, PangolinStageConfig) else PangolinStageConfig(**stage) for stage in stages
        ]
        self.hidden_act = hidden_act
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.num_ensemble = num_ensemble
        self.num_tissues = num_tissues
        self.tissue_names = _resolve_tissue_names(num_tissues, tissue_names)
        self.problem_type = problem_type
        if head is None:
            head = HeadConfig(num_labels=num_labels, hidden_size=hidden_size, problem_type=problem_type)
        elif not isinstance(head, HeadConfig):
            head = HeadConfig(**head)
        self.head = head
        self.output_contexts = output_contexts

        if vocab_size <= pad_token_id:
            raise ValueError(f"vocab_size ({vocab_size}) must include pad_token_id ({pad_token_id}).")
        if hidden_size <= 0:
            raise ValueError(f"hidden_size must be positive, got {hidden_size}.")
        if context <= 0 or context % 2:
            raise ValueError(f"context must be a positive even integer, got {context}.")
        if min(num_ensemble, num_tissues, num_labels) <= 0:
            raise ValueError("num_ensemble, num_tissues, and num_labels must be positive.")
        if len(self.tissue_names) != num_tissues:
            raise ValueError(f"Expected {num_tissues} tissue names, got {len(self.tissue_names)}.")
        for index, stage in enumerate(self.stages):
            if min(stage.num_blocks, stage.kernel_size, stage.dilation) <= 0:
                raise ValueError(f"Stage {index} has non-positive block, kernel, or dilation values: {stage}.")

PangolinStageConfig

Bases: FlatDict

Configuration for a single Pangolin stage.

A stage is a contiguous group of dilated residual blocks that share a kernel size and dilation, followed by a skip-connection convolution.
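Here is a quick consistency check on the default stages from the `PangolinConfig` source above. Assume each residual block contains two dilated convolutions, as in the SpliceAI residual unit that Pangolin builds on; the block internals are not shown in this section, so this is an assumption. Each convolution then widens the receptive field by (kernel_size - 1) * dilation, and the four default stages add up to exactly the default `context` of 10000. The names `DEFAULT_STAGES` and `receptive_field` are hypothetical.

```python
from typing import List, Tuple

# (num_blocks, kernel_size, dilation) for the default PangolinConfig stages.
DEFAULT_STAGES: List[Tuple[int, int, int]] = [
    (4, 11, 1),
    (4, 11, 4),
    (4, 21, 10),
    (4, 41, 25),
]


def receptive_field(stages: List[Tuple[int, int, int]], convs_per_block: int = 2) -> int:
    """Total flanking context (excluding the centre position) seen by stacked dilated convolutions."""
    return sum(
        num_blocks * convs_per_block * (kernel_size - 1) * dilation
        for num_blocks, kernel_size, dilation in stages
    )


print(receptive_field(DEFAULT_STAGES))  # 10000, i.e. 5000 nucleotides on each side
```

The 4 x 4 = 16 blocks also match the Num Layers value in the Model Specification table.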

Parameters:

  • num_blocks (int, required): Number of dilated residual blocks in the stage.
  • kernel_size (int, required): Convolution kernel size for the blocks in the stage.
  • dilation (int, required): Dilation (atrous rate) for the blocks in the stage.
Source code in multimolecule/models/pangolin/configuration_pangolin.py
Python
class PangolinStageConfig(FlatDict):
    r"""
    Configuration for a single Pangolin stage.

    A stage is a contiguous group of dilated residual blocks that share a kernel size and dilation, followed by a
    skip-connection convolution.

    Args:
        num_blocks:
            Number of dilated residual blocks in the stage.
        kernel_size:
            Convolution kernel size for the blocks in the stage.
        dilation:
            Dilation (atrous rate) for the blocks in the stage.
    """

    num_blocks: int = 4
    kernel_size: int = 11
    dilation: int = 1

PangolinForTokenPrediction

Bases: PangolinPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import PangolinConfig, PangolinForTokenPrediction, PangolinStageConfig
>>> stage = PangolinStageConfig(num_blocks=1, kernel_size=3, dilation=1)
>>> config = PangolinConfig(context=4, num_tissues=1, num_ensemble=1, num_labels=1, stages=[stage])
>>> model = PangolinForTokenPrediction(config)
>>> input_ids = torch.randint(5, (1, 5))
>>> output = model(input_ids, labels=torch.rand(1, 5, 1))
>>> output["logits"].shape
torch.Size([1, 5, 1])
>>> output["loss"]
tensor(..., grad_fn=<MseLossBackward0>)
Source code in multimolecule/models/pangolin/modeling_pangolin.py
Python
class PangolinForTokenPrediction(PangolinPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import PangolinConfig, PangolinForTokenPrediction, PangolinStageConfig
        >>> stage = PangolinStageConfig(num_blocks=1, kernel_size=3, dilation=1)
        >>> config = PangolinConfig(context=4, num_tissues=1, num_ensemble=1, num_labels=1, stages=[stage])
        >>> model = PangolinForTokenPrediction(config)
        >>> input_ids = torch.randint(5, (1, 5))
        >>> output = model(input_ids, labels=torch.rand(1, 5, 1))
        >>> output["logits"].shape
        torch.Size([1, 5, 1])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<MseLossBackward0>)
    """

    def __init__(self, config: PangolinConfig):
        super().__init__(config)
        self.model = PangolinModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | PangolinTokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)

        outputs = self.model(
            input_ids,
            attention_mask=head_attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return PangolinTokenPredictorOutput(
            loss=loss,
            logits=logits,
            contexts=outputs.contexts,
            hidden_states=outputs.hidden_states,
        )

PangolinModel

Bases: PangolinPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import PangolinConfig, PangolinModel, PangolinStageConfig
>>> stage = PangolinStageConfig(num_blocks=1, kernel_size=3, dilation=1)
>>> config = PangolinConfig(context=4, num_tissues=1, num_ensemble=1, stages=[stage])
>>> model = PangolinModel(config)
>>> input_ids = torch.randint(5, (1, 5))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 5, 32])
>>> output["logits"].shape
torch.Size([1, 5, 3])
Source code in multimolecule/models/pangolin/modeling_pangolin.py
Python
class PangolinModel(PangolinPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import PangolinConfig, PangolinModel, PangolinStageConfig
        >>> stage = PangolinStageConfig(num_blocks=1, kernel_size=3, dilation=1)
        >>> config = PangolinConfig(context=4, num_tissues=1, num_ensemble=1, stages=[stage])
        >>> model = PangolinModel(config)
        >>> input_ids = torch.randint(5, (1, 5))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 5, 32])
        >>> output["logits"].shape
        torch.Size([1, 5, 3])
    """

    def __init__(self, config: PangolinConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = PangolinEmbedding(config)
        # The official Pangolin release uses one replicate ensemble per tissue output group.
        self.members = nn.ModuleList(
            [
                nn.ModuleList([PangolinModule(config) for _ in range(config.num_ensemble)])
                for _ in range(config.num_tissues)
            ]
        )

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> PangolinModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        output_contexts = kwargs.get("output_contexts", self.config.output_contexts)
        output_hidden_states = kwargs.get("output_hidden_states", self.config.output_hidden_states)
        record_contexts = bool(output_contexts) or bool(output_hidden_states)
        kwargs["output_contexts"] = record_contexts
        kwargs["output_hidden_states"] = record_contexts

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        member_outputs = [
            [member(embedding_output, **kwargs) for member in tissue_members] for tissue_members in self.members
        ]
        flat_outputs = [output for tissue_outputs in member_outputs for output in tissue_outputs]

        last_hidden_state = _average_tensors([out.last_hidden_state for out in flat_outputs])
        tissue_logits = []
        for tissue, tissue_outputs in enumerate(member_outputs):
            logits = _average_tensors([out.logits for out in tissue_outputs])
            start = tissue * 3
            tissue_logits.append(logits[..., start : start + 3])
        logits = torch.cat(tissue_logits, dim=-1)

        contexts: tuple[Tensor, ...] | None = None
        if record_contexts:
            per_member_contexts = [out.contexts for out in flat_outputs if out.contexts is not None]
            if per_member_contexts:
                num_stages = len(per_member_contexts[0])
                contexts = tuple(_average_tensors([m[idx] for m in per_member_contexts]) for idx in range(num_stages))

        return PangolinModelOutput(
            last_hidden_state=last_hidden_state,
            logits=logits,
            contexts=contexts if output_contexts else None,
            hidden_states=contexts if output_hidden_states else None,
        )

    @property
    def output_channels(self) -> list[str]:
        channels = []
        for tissue in self.config.tissue_names:
            channels.extend(
                [
                    f"{tissue}_no_splice",
                    f"{tissue}_splice_site",
                    f"{tissue}_usage",
                ]
            )
        return channels

    def postprocess(self, outputs: PangolinModelOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
        r"""
        Return Pangolin splice-site scores with semantic tissue channel names.

        Pangolin logits already contain probability-like outputs from the original head: two softmax splice-site
        channels and one sigmoid usage channel for each tissue. This method attaches the model-defined tissue channel
        names so direct model users and pipelines share the same output semantics.

        Args:
            outputs: The output of [`PangolinModel`][multimolecule.models.PangolinModel], or its `logits` tensor.

        Returns:
            A tuple of `(scores, channels)`, where `scores` has shape `(batch_size, sequence_length, num_tissues * 3)`
            and `channels` follows `config.tissue_names`.
        """
        logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
        return logits, self.output_channels

postprocess

Python
postprocess(
    outputs: PangolinModelOutput | ModelOutput | Tensor,
) -> tuple[Tensor, list[str]]

Return Pangolin splice-site scores with semantic tissue channel names.

Pangolin logits already contain probability-like outputs from the original head: two softmax splice-site channels and one sigmoid usage channel for each tissue. This method attaches the model-defined tissue channel names so direct model users and pipelines share the same output semantics.

Parameters:

  • outputs (PangolinModelOutput | ModelOutput | Tensor, required): The output of PangolinModel, or its logits tensor.

Returns:

  • tuple[Tensor, list[str]]: A tuple of (scores, channels), where scores has shape (batch_size, sequence_length, num_tissues * 3) and channels lists the channel names in config.tissue_names order ({tissue}_no_splice, {tissue}_splice_site, and {tissue}_usage for each tissue).
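Because postprocess returns the channel names alongside the scores, downstream code can look channels up by name instead of hard-coding indices. The minimal sketch below uses plain nested lists in place of the score tensor. With real output, `scores, channels = model.postprocess(output)` followed by `scores[0].tolist()` should give this form for the first sequence in the batch. `top_positions` is a hypothetical helper.

```python
from typing import List, Tuple


def top_positions(
    scores: List[List[float]], channels: List[str], channel: str, k: int = 3
) -> List[Tuple[int, float]]:
    """Hypothetical helper: the k (position, score) pairs with the highest value in a named channel.

    `scores` holds a single sequence, indexed as scores[position][channel_index].
    """
    index = channels.index(channel)  # ValueError if the name is unknown
    column = [(position, row[index]) for position, row in enumerate(scores)]
    return sorted(column, key=lambda item: item[1], reverse=True)[:k]


channels = ["heart_no_splice", "heart_splice_site", "heart_usage"]  # one-tissue example
scores = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.6], [0.6, 0.4, 0.3]]  # made-up values
print(top_positions(scores, channels, "heart_splice_site", k=2))  # [(1, 0.8), (2, 0.4)]
```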


PangolinModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the Pangolin model.

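Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `last_hidden_state` | `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)` | Per-position encoder representation, averaged across ensemble members. Consumed by `TokenPredictionHead`. | `None` |
| `logits` | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_tissues * 3)` | Original Pangolin per-tissue splice-site score (softmax, 2 channels) and splice-site usage score (sigmoid, 1 channel) outputs, averaged across ensemble members. | `None` |
| `contexts` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True` is passed or when `config.output_contexts=True` | Tuple of `torch.FloatTensor` (one per encoder stage) of shape `(batch_size, sequence_length, hidden_size)`. Skip vectors recorded after each stage, averaged across ensemble members. | `None` |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Same content as `contexts`; provided for compatibility with the Transformers hidden-states convention. | `None` |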
Source code in multimolecule/models/pangolin/modeling_pangolin.py
Python
@dataclass
class PangolinModelOutput(ModelOutput):
    """
    Base class for outputs of the Pangolin model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Per-position encoder representation, averaged across ensemble members. Consumed by
            [`TokenPredictionHead`].
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_tissues * 3)`):
            Original Pangolin per-tissue splice-site score (softmax, 2 channels) and splice-site usage score (sigmoid,
            1 channel) outputs, averaged across ensemble members.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True` is passed or when
            `config.output_contexts=True`):
            Tuple of `torch.FloatTensor` (one per stage of the encoder) of shape `(batch_size, sequence_length,
            hidden_size)`. Skip vectors recorded after each stage, averaged across ensemble members.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Same content as `contexts`; provided for compatibility with the Transformers hidden-states convention.
    """

    last_hidden_state: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contexts: tuple[torch.FloatTensor, ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
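Several fields above are described as averaged across ensemble members. The modeling code that performs the averaging is not part of this section, so the following is only a sketch of a plain element-wise mean under that description: each member's tensor is stacked along a new leading axis and reduced, and `contexts` are averaged stage by stage. `average_members` and `average_contexts` are hypothetical names, and the three toy members are illustrative only.

```python
import torch


def average_members(member_outputs: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: element-wise mean over ensemble members."""
    # Stack to (num_members, batch, length, channels), then reduce the member axis.
    return torch.stack(member_outputs, dim=0).mean(dim=0)


def average_contexts(member_contexts: list[tuple[torch.Tensor, ...]]) -> tuple[torch.Tensor, ...]:
    """Hypothetical helper: average per-stage skip vectors across members."""
    # zip(*...) groups the i-th stage of every member together.
    return tuple(average_members(list(stage)) for stage in zip(*member_contexts))


# Three toy members, each producing logits of shape (batch=2, length=7, channels=12).
members = [torch.full((2, 7, 12), float(value)) for value in (1, 2, 3)]
print(average_members(members)[0, 0, 0].item())  # 2.0

# Each toy member records two encoder stages of shape (2, 7, 32).
contexts = [tuple(torch.full((2, 7, 32), float(v)) for _ in range(2)) for v in (0, 3, 6)]
averaged = average_contexts(contexts)
print(len(averaged), averaged[0].shape)  # 2 torch.Size([2, 7, 32])
```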

PangolinPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/pangolin/modeling_pangolin.py
Python
class PangolinPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PangolinConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["PangolinBlock"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
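The `nn.Conv1d` branch combines Kaiming-normal weights (`mode="fan_out"`, ReLU gain) with a uniform bias in `[-1/sqrt(fan_in), 1/sqrt(fan_in)]`. For a weight of shape `(out_channels, in_channels, kernel_size)`, `fan_in = in_channels * kernel_size` and `fan_out = out_channels * kernel_size`, so the weight standard deviation is `sqrt(2 / fan_out)`. The sketch below reproduces that scheme on a standalone layer with plain `torch.nn.init`; unlike the `transformers.initialization` wrappers used above, these calls do not check `_is_hf_initialized`. The 32-channel, kernel-size-11 layer and the helper name `init_conv1d_like_pangolin` are illustrative assumptions, not read from `PangolinConfig`.

```python
import math

import torch
from torch import nn


def init_conv1d_like_pangolin(conv: nn.Conv1d) -> None:
    """Hypothetical sketch of the Conv1d branch of `_init_weights`, using plain torch.nn.init."""
    _, in_channels_per_group, kernel_size = conv.weight.shape
    fan_in = in_channels_per_group * kernel_size
    with torch.no_grad():
        # Kaiming normal, fan_out mode: std = sqrt(2) / sqrt(out_channels * kernel_size).
        nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
        if conv.bias is not None:
            # Same bias bound as the PyTorch default: 1 / sqrt(fan_in).
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(conv.bias, -bound, bound)


torch.manual_seed(0)
conv = nn.Conv1d(32, 32, kernel_size=11)  # illustrative size, not read from PangolinConfig
init_conv1d_like_pangolin(conv)

expected_std = math.sqrt(2 / (32 * 11))  # fan_out = 352
print(round(expected_std, 4), round(conv.weight.std().item(), 4))
```

Using `fan_out` for the weights preserves the variance of gradients in the backward pass, while the bias bound matches PyTorch's own default for convolution layers.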

PangolinTokenPredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of Pangolin token prediction models.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `torch.FloatTensor`, *optional*, returned when `labels` is provided | Token prediction loss. | `None` |
| `logits` | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)` | Per-nucleotide prediction outputs. | `None` |
| `contexts` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True` | Per-stage context representations. | `None` |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` | Per-stage context representations. | `None` |
Source code in multimolecule/models/pangolin/modeling_pangolin.py
Python
@dataclass
class PangolinTokenPredictorOutput(ModelOutput):
    """
    Base class for outputs of Pangolin token prediction models.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Token prediction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-nucleotide prediction outputs.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`):
            Per-stage context representations.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Per-stage context representations.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contexts: tuple[torch.FloatTensor, ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None