
SpTransformer

Transformer network for predicting tissue-specific splicing from pre-mRNA sequences.

Disclaimer

This is an UNOFFICIAL implementation of "SpliceTransformer predicts tissue-specific splicing linked to human diseases" by Ningyuan You et al.

The OFFICIAL repository of SpliceTransformer (SpTransformer) is at ShenLab-Genomics/SpliceTransformer.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing SpTransformer did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

SpTransformer (SpliceTransformer) is a deep neural network that predicts tissue-specific splicing from primary pre-mRNA sequence. It combines two pretrained SpliceAI-style dilated-residual convolutional feature extractors with a trainable input-projection path; the concatenated features are processed by a Sinkhorn transformer attention block with axial positional embeddings. For each position the network predicts a 3-channel splice-site score (no-splice / acceptor / donor) and a per-position splice-site usage score across 15 human tissues. The model uses a fixed flanking context of 4,000 nucleotides on each side of every predicted position. SpTransformer is typically used to estimate the effect of genetic variants on tissue-specific splicing by scoring reference and alternate sequences and taking the difference. Please refer to the Training Details section for more information on the training process.
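Variant-effect scoring as described above (score the reference and alternate sequences, then take the difference) can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' pipeline. predict_fn and toy_predict are hypothetical stand-ins for a model call that returns a (length, channels) array aligned to input positions. Summarising each channel by its maximum absolute change is also an assumed aggregation.

```python
import numpy as np


def substitute(seq: str, pos: int, ref: str, alt: str) -> str:
    """Return seq with the nucleotide at 0-based pos changed from ref to alt."""
    if seq[pos] != ref:
        raise ValueError(f"reference mismatch at {pos}: expected {ref}, found {seq[pos]}")
    return seq[:pos] + alt + seq[pos + 1 :]


def delta_scores(predict_fn, seq: str, pos: int, ref: str, alt: str) -> np.ndarray:
    """Score reference and alternate sequences; return alt - ref per position and channel."""
    ref_scores = np.asarray(predict_fn(seq), dtype=float)
    alt_scores = np.asarray(predict_fn(substitute(seq, pos, ref, alt)), dtype=float)
    return alt_scores - ref_scores


def max_abs_delta(delta: np.ndarray) -> np.ndarray:
    """Summarise each channel by its largest absolute change over positions (assumed summary)."""
    return np.abs(delta).max(axis=0)


def toy_predict(seq: str) -> np.ndarray:
    """Hypothetical stand-in for the model: 1.0 wherever an 'AG' dinucleotide starts."""
    return np.array([[1.0 if seq[i : i + 2] == "AG" else 0.0] for i in range(len(seq))])


delta = delta_scores(toy_predict, "CCAGCC", pos=2, ref="A", alt="C")
print(max_abs_delta(delta))  # strongest change per output channel
```

For a single-nucleotide substitution, the reference and alternate inputs have the same length. The positions therefore line up, and a plain element-wise subtraction is meaningful.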

Model Specification

| Num Layers | Hidden Size | Num Heads | Intermediate Size | Max Seq Len | Num Parameters (M) | FLOPs (G) | MACs (G) | Context |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 8 | 256 | 8 | 1024 | 8192 | 17.07 | 290.72 | 144.65 | 4000 |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

RNA Splicing Site Prediction

You can use this model directly to predict per-nucleotide tissue-specific splicing of a pre-mRNA sequence:

Python
>>> from multimolecule import DnaTokenizer, SpTransformerModel

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/sptransformer")
>>> model = SpTransformerModel.from_pretrained("multimolecule/sptransformer")
>>> output = model(tokenizer("AGCAGTCATTATGGCGAA", return_tensors="pt")["input_ids"])

>>> output.keys()
odict_keys(['last_hidden_state', 'logits'])

The logits tensor reproduces the original SpTransformer output: a 3-channel splice-site score (no-splice / acceptor / donor) and a per-tissue (15 tissues) splice-site usage score for each position.
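Downstream code usually needs the two outputs separately. The card does not state how the 3 splice-site channels and 15 tissue channels are laid out inside the logits tensor. The sketch below therefore assumes a last axis of size 18 with the splice-site channels first; check this against the actual tensor before relying on it. split_logits and call_sites are hypothetical helpers, and argmax is only a simple illustrative decision rule.

```python
import numpy as np

NUM_SPLICE_LABELS = 3  # no-splice / acceptor / donor, per the model card
NUM_TISSUES = 15  # per-tissue splice-site usage channels, per the model card
SPLICE_CLASSES = ("no-splice", "acceptor", "donor")


def split_logits(logits):
    """Split a (..., 3 + 15) array into splice-site scores and per-tissue usage.

    ASSUMPTION: the three splice-site channels come first along the last axis.
    """
    logits = np.asarray(logits)
    expected = NUM_SPLICE_LABELS + NUM_TISSUES
    if logits.shape[-1] != expected:
        raise ValueError(f"expected {expected} channels on the last axis, got {logits.shape[-1]}")
    return logits[..., :NUM_SPLICE_LABELS], logits[..., NUM_SPLICE_LABELS:]


def call_sites(splice_scores):
    """Label each position with its highest-scoring splice-site class."""
    best = np.argmax(splice_scores, axis=-1)
    return [SPLICE_CLASSES[i] for i in best.reshape(-1)]


# Toy (batch=1, length=4, channels=18) array; with a real model output you might
# first convert it, e.g. output.logits.detach().cpu().numpy().
logits = np.zeros((1, 4, NUM_SPLICE_LABELS + NUM_TISSUES))
logits[0, 1, 1] = 5.0  # acceptor-like position
logits[0, 2, 2] = 5.0  # donor-like position
splice, tissue = split_logits(logits)
print(call_sites(splice[0]))
```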

Downstream Use

Token Prediction

You can fine-tune SpTransformer for per-nucleotide tissue-specific splicing regression with SpTransformerForTokenPrediction, which adds a shared token prediction head on top of the backbone.

Interface

  • Input length: variable pre-mRNA sequence
  • Flanking context: fixed 4,000 nt on each side of every predicted position
  • Padding: ends padded with N
  • Output: per-position 3-channel splice-site score (no-splice / acceptor / donor) + per-tissue (15 tissues) splice-site usage score
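The sketch below makes the interface concrete. It builds the N-padded window for a short sequence and maps it to integer ids. It assumes the vocabulary order A=0, C=1, G=2, T=3, N=4. N as id 4 matches the configuration's pad_token_id and unk_token_id, but the order of the other four bases is an assumption. pad_with_n and encode are hypothetical helpers. In practice the DnaTokenizer from the usage example handles tokenization, and this card does not state whether you must pad yourself or the model pads internally.

```python
CONTEXT = 4000  # flanking nucleotides on each side (from the model card)
# ASSUMED id order; only N = 4 is confirmed by pad_token_id / unk_token_id in the config.
VOCAB = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}


def pad_with_n(seq: str, context: int = CONTEXT) -> str:
    """Flank the sequence with `context` N characters on both ends."""
    return "N" * context + seq.upper() + "N" * context


def encode(seq: str) -> list[int]:
    """Map bases to ids; U becomes T and unknown characters fall back to N."""
    return [VOCAB.get(base, VOCAB["N"]) for base in seq.upper().replace("U", "T")]


window = pad_with_n("acgu", context=2)
print(window, encode(window))
```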
  • Attention recording: opt-in via output_attentions=True; returns faithful sparse-attention maps — see Faithful Sparse-Attention Exposure

Faithful Sparse-Attention Exposure

SpTransformer’s attention block does not compute dense self-attention. Each layer (SpTransformerSelfAttention) splits its heads into two groups with fundamentally different sparse-attention structures:

  • Windowed-local heads: each window of bucket_size tokens attends only to itself plus the immediately preceding and following windows (a look-around with look_backward=1, look_forward=1). Out-of-range neighbour windows at the sequence boundaries are padded and masked.
  • Sinkhorn sorted-bucket heads: each query bucket attends to the concatenation of (a) one sorted/reordered key bucket selected by a parameter-free attention-sort net (differentiable_topk(R, k=1)) and (b) its own local bucket.
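A few lines of NumPy can reproduce the look-around layout of the windowed-local heads. This hypothetical helper mirrors the documented local_key_positions convention: each key column holds its absolute source position, and -1 marks padded out-of-range windows. It is an illustration, not the library's implementation.

```python
import numpy as np


def local_key_positions(sequence_length: int, window: int, look_backward: int = 1, look_forward: int = 1):
    """Absolute source position of every look-around key column; -1 marks masked padding."""
    num_windows = sequence_length // window
    rows = []
    for w in range(num_windows):
        row = []
        for neighbour in range(w - look_backward, w + look_forward + 1):
            if 0 <= neighbour < num_windows:
                row.extend(range(neighbour * window, (neighbour + 1) * window))
            else:
                row.extend([-1] * window)  # out-of-range neighbour window: padded and masked
        rows.append(row)
    return np.array(rows, dtype=np.int64)


positions = local_key_positions(sequence_length=16, window=4)
print(positions.shape)  # (num_windows, 3 * W)
print(positions[0])     # first window: no preceding window, so the first W columns are -1
```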

Because these two patterns operate on different key axes, there is no single dense (batch, heads, sequence, sequence) tensor that faithfully represents the computation. Materialising a zero-filled sequence x sequence grid would be a misleading interpretability artifact, so this model does not expose one.

Instead, attention recording is opt-in and faithful. Passing output_attentions=True (or setting config.output_attentions=True) returns, for every attention layer, a SpTransformerAttentionMap holding the actual softmax weights used in the forward pass plus the indexing/permutation needed to map them back to absolute sequence positions:

  • local_attentions (B, num_local_heads, num_windows, W, (look_backward + 1 + look_forward) * W) — the real per-window softmax weights; padded look-around columns carry weight 0.
  • local_key_positions (num_windows, (look_backward + 1 + look_forward) * W) — absolute source position of every local key-axis column (-1 marks padded columns).
  • sinkhorn_attentions (B, num_sinkhorn_heads, num_buckets, W, 2 * W) — the real per-bucket softmax weights over the [reordered-bucket | own-bucket] key axis.
  • sinkhorn_reorder (B, num_sinkhorn_heads, num_buckets, num_buckets) — the exact bucket-permutation matrix; for query bucket u, the nonzero column v of row u says the reordered key bucket (columns 0:W of sinkhorn_attentions) is source bucket v (absolute positions v*W : v*W + W).
  • scalar metadata: bucket_size, look_backward, look_forward, num_local_heads, num_sinkhorn_heads, sequence_length.
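The bucket permutation in sinkhorn_reorder can be converted into absolute key positions using the rule stated for that field. The helper below is a hypothetical sketch for one batch element and one Sinkhorn head. For a real map you might slice something like layer0.sinkhorn_reorder[0, h] and convert it to NumPy first. The sketch assumes each row has exactly one nonzero entry, as described above.

```python
import numpy as np


def sinkhorn_key_positions(reorder, window: int) -> np.ndarray:
    """Map the 2 * W key columns of each query bucket to absolute sequence positions.

    reorder: (num_buckets, num_buckets) slice for one batch element and one head.
    Columns 0:W come from source bucket v (the nonzero column of row u);
    columns W:2W are the query bucket's own positions u*W : u*W + W.
    """
    reorder = np.asarray(reorder)
    num_buckets = reorder.shape[0]
    source_bucket = np.abs(reorder).argmax(axis=-1)  # v for every query bucket u
    offsets = np.arange(window)
    reordered = source_bucket[:, None] * window + offsets
    own = np.arange(num_buckets)[:, None] * window + offsets
    return np.concatenate([reordered, own], axis=-1)


# Toy permutation: bucket 0 looks at bucket 1, bucket 1 at bucket 2, bucket 2 at bucket 0.
reorder = np.array([[0.0, 0.9, 0.0], [0.0, 0.0, 1.0], [0.7, 0.0, 0.0]])
print(sinkhorn_key_positions(reorder, window=2))
```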

W is bucket_size; local heads come first along the head axis, Sinkhorn heads second. These are structured block weights, not dense attention matrices. Contracting these exact weights with the (block-gathered) values re-derives each attention type's output and reproduces the layer's attention output exactly. Recording is opt-in, so the default forward path and its numerics are byte-for-byte unchanged.
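The faithfulness claim can be checked on a toy problem. This NumPy sketch covers only a single batch element and a single windowed-local head. The 1/sqrt(D) scaling and the random q, k and v are assumptions for illustration, and none of these helpers are part of multimolecule. The sketch records structured per-window weights, then re-derives the output by contracting them with block-gathered values.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)  # masked (-inf) entries become exactly 0
    return e / e.sum(axis=axis, keepdims=True)


def look_around_positions(num_windows, window, back=1, forward=1):
    """(num_windows, (back + 1 + forward) * window) absolute key positions; -1 = padding."""
    rows = []
    for w in range(num_windows):
        row = []
        for n in range(w - back, w + forward + 1):
            row.extend(range(n * window, (n + 1) * window) if 0 <= n < num_windows else [-1] * window)
        rows.append(row)
    return np.array(rows)


def gather_blocks(x, key_positions):
    """Block-gather rows of x (S, D) by key position; padded columns become zeros."""
    safe = np.where(key_positions >= 0, key_positions, 0)
    return np.where(key_positions[..., None] >= 0, x[safe], 0.0)


def local_attention_weights(q, k, window):
    """Structured per-window softmax weights for one batch element and one head."""
    seq_len, dim = q.shape
    key_positions = look_around_positions(seq_len // window, window)
    qb = q.reshape(-1, window, dim)  # (num_windows, W, D)
    kb = gather_blocks(k, key_positions)  # (num_windows, 3W, D)
    scores = np.einsum("nid,njd->nij", qb, kb) / np.sqrt(dim)
    scores = np.where(key_positions[:, None, :] >= 0, scores, -np.inf)  # mask padding
    return softmax(scores), key_positions


def contract(weights, key_positions, v):
    """Re-derive the attention output from recorded weights and block-gathered values."""
    vb = gather_blocks(v, key_positions)
    return np.einsum("nij,njd->nid", weights, vb).reshape(-1, v.shape[-1])


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
weights, key_positions = local_attention_weights(q, k, window=4)
out = contract(weights, key_positions, v)
print(weights.shape, out.shape)
```

The contracted output matches dense softmax attention restricted to the same three-window band. This is why the structured weights are sufficient and no zero-filled sequence x sequence grid is needed.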

Python
>>> import torch
>>> from multimolecule import SpTransformerConfig, SpTransformerModel
>>> config = SpTransformerConfig(bucket_size=4, max_seq_len=16, context=2, num_hidden_layers=2)
>>> model = SpTransformerModel(config)
>>> output = model(torch.randint(config.vocab_size, (1, 16)), output_attentions=True)
>>> layer0 = output.attentions[0]
>>> layer0.local_attentions.shape
torch.Size([1, 2, 4, 4, 12])
>>> layer0.sinkhorn_attentions.shape
torch.Size([1, 6, 4, 4, 8])
>>> layer0.sinkhorn_reorder.shape
torch.Size([1, 6, 4, 4])
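The shapes in this example follow directly from the configuration. num_windows = num_buckets = 16 / 4 = 4, and the local key axis has (1 + 1 + 1) * 4 = 12 columns. With the default 8 attention heads and 2 local heads, there are 6 Sinkhorn heads. The hypothetical helper below computes the same shapes. It assumes the attention-block sequence length equals max_seq_len and that look_backward = look_forward = 1.

```python
def expected_attention_shapes(batch, sequence_length, bucket_size, num_heads, num_local_heads,
                              look_backward=1, look_forward=1):
    """Shapes of the fields of one SpTransformerAttentionMap, derived from the config."""
    if sequence_length % bucket_size:
        raise ValueError("sequence_length must be divisible by bucket_size")
    n = sequence_length // bucket_size  # num_windows == num_buckets
    w = bucket_size
    num_sinkhorn = num_heads - num_local_heads
    span = (look_backward + 1 + look_forward) * w
    return {
        "local_attentions": (batch, num_local_heads, n, w, span),
        "local_key_positions": (n, span),
        "sinkhorn_attentions": (batch, num_sinkhorn, n, w, 2 * w),
        "sinkhorn_reorder": (batch, num_sinkhorn, n, n),
    }


print(expected_attention_shapes(batch=1, sequence_length=16, bucket_size=4, num_heads=8, num_local_heads=2))
```

With the default configuration (max_seq_len=8192, bucket_size=64), the same arithmetic gives 128 windows per layer.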

Training Details

SpTransformer was trained to predict tissue-specific splicing from primary pre-mRNA sequence.

Training Data

SpTransformer was trained on splicing measurements derived from RNA-seq data across 15 human tissues, using gene annotations from GENCODE, together with multi-species sequence data. The two convolutional feature extractors were pre-trained as SpliceAI-style splice-site predictors and remain trainable submodules for downstream fine-tuning. For each predicted nucleotide, a sequence window centered on that nucleotide was used, with the flanking context padded with N (unknown nucleotide) when near transcript ends.

Training Procedure

Pre-training

The model was trained to minimize a combination of cross-entropy loss over splice-site classification and a regression loss over per-tissue splice-site usage, comparing predictions against measurements derived from RNA-seq.
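A combined objective of this kind can be sketched as follows. The card does not specify the regression loss, its relative weighting or any masking. This NumPy sketch therefore assumes mean squared error for tissue usage and a hypothetical usage_weight. Treat it as an illustration of the structure of the objective, not the authors' training code.

```python
import numpy as np


def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def combined_loss(splice_logits, splice_labels, usage_pred, usage_target, usage_weight=1.0):
    """Cross-entropy over splice-site classes plus a regression term over tissue usage.

    splice_logits: (L, 3) raw scores; splice_labels: (L,) ints in {0, 1, 2}
    usage_pred, usage_target: (L, 15)
    ASSUMPTIONS: MSE as the regression loss and a hypothetical scalar usage_weight.
    """
    logp = log_softmax(np.asarray(splice_logits, dtype=float))
    labels = np.asarray(splice_labels)
    cross_entropy = -logp[np.arange(len(labels)), labels].mean()
    mse = np.mean((np.asarray(usage_pred, dtype=float) - np.asarray(usage_target, dtype=float)) ** 2)
    return cross_entropy + usage_weight * mse


loss = combined_loss(np.zeros((2, 3)), np.array([0, 1]), np.zeros((2, 15)), np.zeros((2, 15)))
print(loss)  # uniform logits over 3 classes and a perfect usage fit give log(3)
```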

Citation

BibTeX
@article{You2024,
  author    = {You, Ningyuan and Liu, Chang and Gu, Yuxin and Wang, Rong and Jia, Hanying and Zhang, Tianyun and Jiang, Song and Shi, Jinsong and Chen, Ming and Guan, Min-Xin and Sun, Siqi and Pei, Shanshan and Liu, Zhihong and Shen, Ning},
  title     = {{SpliceTransformer predicts tissue-specific splicing linked to human diseases}},
  journal   = {Nature Communications},
  year      = {2024},
  volume    = {15},
  number    = {1},
  pages     = {9129},
  month     = {oct},
  doi       = {10.1038/s41467-024-53088-6},
  issn      = {2041-1723},
  url       = {https://doi.org/10.1038/s41467-024-53088-6}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the SpliceTransformer paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.sptransformer

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

Name Type Description Default

alphabet

Alphabet | str | List[str] | None

alphabet to use for tokenization.

  • If None, the standard DNA alphabet will be used.
  • If a string, it should correspond to the name of a predefined alphabet. The options include
    • standard
    • iupac
    • streamline
    • nucleobase
  • If an alphabet or a list of characters, that specific alphabet will be used.
None

nmers

int

Size of kmer to tokenize.

1

codon

bool

Whether to tokenize into codons.

False

replace_U_with_T

bool

Whether to replace U with T.

True

do_upper_case

bool

Whether to convert input to uppercase.

True

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
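The two slicing branches of _tokenize can be compared in isolation. These pure-Python helpers are illustrative restatements, not part of the library. With nmers=3, 'tataaagta' yields 7 overlapping 3-mers, which matches the 7 ids between <cls> (1) and <eos> (2) in the example above. With codon=True it yields only 3 non-overlapping triplets. This is consistent with the ids 84, 6 and 71 appearing in both outputs for TAT, AAA and GTA.

```python
def overlapping_kmers(seq: str, k: int) -> list[str]:
    """Sliding-window k-mers with stride 1, like the nmers > 1 branch of _tokenize."""
    return [seq[i : i + k] for i in range(len(seq) - k + 1)]


def codons(seq: str) -> list[str]:
    """Non-overlapping triplets, like the codon=True branch of _tokenize."""
    if len(seq) % 3 != 0:
        raise ValueError(f"length of input sequence must be a multiple of 3, but got {len(seq)}")
    return [seq[i : i + 3] for i in range(0, len(seq), 3)]


print(overlapping_kmers("TATAAAGTA", 3))
print(codons("TATAAAGTA"))
```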

SpTransformerConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a SpTransformerModel. It is used to instantiate a SpTransformer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SpliceTransformer ShenLab-Genomics/SpliceTransformer architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

Name Type Description Default

vocab_size

int

Vocabulary size of the SpTransformer model. Defines the number of different tokens that can be represented by the input_ids passed when calling SpTransformerModel. Defaults to 5 (A, C, G, T, N).

5

context

int

The length of the flanking context. The encoder consumes this many nucleotides of flanking sequence on each side of every predicted position.

4000

hidden_size

int

Dimensionality of the trainable input-projection path.

128

encoders

list[SpTransformerFeatureEncoderConfig] | None

Configuration for each SpliceAI-style convolutional feature encoder. Each encoder is a SpTransformerFeatureEncoderConfig object.

None

attention_hidden_size

int

Dimensionality of the Sinkhorn transformer attention block.

256

num_hidden_layers

int

Number of layers in the Sinkhorn transformer attention block.

8

num_attention_heads

int

Number of attention heads in the Sinkhorn transformer attention block.

8

num_local_attention_heads

int

Number of attention heads that use local (windowed) attention instead of Sinkhorn attention.

2

intermediate_size

int

Dimensionality of the feed-forward layers in the attention block.

1024

bucket_size

int

Token bucket size for Sinkhorn / local attention.

64

max_seq_len

int

Maximum sequence length consumed by the attention block. The concatenated features are center-cropped or padded to this length before the attention block.

8192

num_splice_labels

int

Number of splice-site score channels predicted by the original output head (no-splice, acceptor, donor).

3

num_tissues

int

Number of tissues for which per-position splice-site usage is predicted by the original output head.

15

tissue_names

list[str] | None

Names for the per-tissue splice-site usage channels. Defaults to tissue_0, tissue_1, …

None

hidden_act

str

The non-linear activation function (function or string) in the SpliceAI-style feature encoders.

'relu'

intermediate_act

str

The non-linear activation function (function or string) in the transformer feed-forward layers.

'gelu'

batch_norm_eps

float

The epsilon used by the batch normalization layers.

1e-05

batch_norm_momentum

float

The momentum used by the batch normalization layers.

0.1

num_labels

int

Number of output labels for the TokenPredictionHead. Defaults to 15: one splice-site usage value per tissue at each position.

15

head

HeadConfig | None

Configuration for the TokenPredictionHead.

None

problem_type

str | None

Problem type for the token prediction head.

'regression'

output_contexts

bool

Whether to output the per-position attention-block representation.

False

Examples:

Python Console Session
>>> from multimolecule import SpTransformerConfig, SpTransformerModel
>>> # Initializing a SpTransformer multimolecule/sptransformer style configuration
>>> configuration = SpTransformerConfig()
>>> # Initializing a model (with random weights) from the multimolecule/sptransformer style configuration
>>> model = SpTransformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/sptransformer/configuration_sptransformer.py
Python
class SpTransformerConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`SpTransformerModel`][multimolecule.models.SpTransformerModel]. It is used to instantiate a SpTransformer model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the SpliceTransformer
    [ShenLab-Genomics/SpliceTransformer](https://github.com/ShenLab-Genomics/SpliceTransformer) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the SpTransformer model. Defines the number of different tokens that can be represented
            by the `input_ids` passed when calling [`SpTransformerModel`].
            Defaults to 5 (`A`, `C`, `G`, `T`, `N`).
        context:
            The length of the context window. The encoder consumes `context` nucleotides of flanking context on each
            side of every predicted position.
        hidden_size:
            Dimensionality of the trainable input-projection path.
        encoders:
            Configuration for each SpliceAI-style convolutional feature encoder. Each encoder is a
            [`SpTransformerFeatureEncoderConfig`] object.
        attention_hidden_size:
            Dimensionality of the Sinkhorn transformer attention block.
        num_hidden_layers:
            Number of layers in the Sinkhorn transformer attention block.
        num_attention_heads:
            Number of attention heads in the Sinkhorn transformer attention block.
        num_local_attention_heads:
            Number of attention heads that use local (windowed) attention instead of Sinkhorn attention.
        intermediate_size:
            Dimensionality of the feed-forward layers in the attention block.
        bucket_size:
            Token bucket size for Sinkhorn / local attention.
        max_seq_len:
            Maximum sequence length consumed by the attention block. The concatenated features are
            center-cropped or padded to this length before the attention block.
        num_splice_labels:
            Number of splice-site score channels predicted by the original output head (no-splice, acceptor,
            donor).
        num_tissues:
            Number of tissues for which per-position splice-site usage is predicted by the original output head.
        tissue_names:
            Names for the per-tissue splice-site usage channels. Defaults to `tissue_0`, `tissue_1`, ...
        hidden_act:
            The non-linear activation function (function or string) in the SpliceAI-style feature encoders.
        intermediate_act:
            The non-linear activation function (function or string) in the transformer feed-forward layers.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers.
        num_labels:
            Number of output labels for the [`TokenPredictionHead`]. Defaults to 15, one per-position
            tissue-specific splice-site usage value.
        head:
            Configuration for the [`TokenPredictionHead`].
        problem_type:
            Problem type for the token prediction head.
        output_contexts:
            Whether to output the per-position attention-block representation.

    Examples:
        >>> from multimolecule import SpTransformerConfig, SpTransformerModel
        >>> # Initializing a SpTransformer multimolecule/sptransformer style configuration
        >>> configuration = SpTransformerConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/sptransformer style configuration
        >>> model = SpTransformerModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "sptransformer"

    # SpTransformer consumes raw nucleotide sequences (`A`, `C`, `G`, `T`, `N`) with no special tokens; `N`
    # doubles as the padding token. Converted checkpoints map the `N` input channel to zero weights.
    pad_token_id: int = 4
    bos_token_id: int | None = None  # type: ignore[assignment]
    eos_token_id: int | None = None  # type: ignore[assignment]
    unk_token_id: int = 4
    mask_token_id: int | None = None  # type: ignore[assignment]
    null_token_id: int | None = None  # type: ignore[assignment]

    def __init__(
        self,
        vocab_size: int = 5,
        context: int = 4000,
        hidden_size: int = 128,
        encoders: list[SpTransformerFeatureEncoderConfig] | None = None,
        attention_hidden_size: int = 256,
        num_hidden_layers: int = 8,
        num_attention_heads: int = 8,
        num_local_attention_heads: int = 2,
        intermediate_size: int = 1024,
        bucket_size: int = 64,
        max_seq_len: int = 8192,
        num_splice_labels: int = 3,
        num_tissues: int = 15,
        tissue_names: list[str] | None = None,
        hidden_act: str = "relu",
        intermediate_act: str = "gelu",
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        num_labels: int = 15,
        head: HeadConfig | None = None,
        problem_type: str | None = "regression",
        output_contexts: bool = False,
        pad_token_id: int = 4,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        unk_token_id: int = 4,
        mask_token_id: int | None = None,
        null_token_id: int | None = None,
        **kwargs,
    ):
        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        self.mask_token_id = mask_token_id  # type: ignore[assignment]
        self.null_token_id = null_token_id  # type: ignore[assignment]
        self.vocab_size = vocab_size
        self.context = context
        self.hidden_size = hidden_size
        if encoders is None:
            encoders = [
                SpTransformerFeatureEncoderConfig(hidden_size=128),
                SpTransformerFeatureEncoderConfig(hidden_size=64),
            ]
        self.encoders = [
            (
                encoder
                if isinstance(encoder, SpTransformerFeatureEncoderConfig)
                else SpTransformerFeatureEncoderConfig(**encoder)
            )
            for encoder in encoders
        ]
        self.attention_hidden_size = attention_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_local_attention_heads = num_local_attention_heads
        self.intermediate_size = intermediate_size
        self.bucket_size = bucket_size
        self.max_seq_len = max_seq_len
        self.num_splice_labels = num_splice_labels
        self.num_tissues = num_tissues
        self.tissue_names = _resolve_tissue_names(num_tissues, tissue_names)
        self.hidden_act = hidden_act
        self.intermediate_act = intermediate_act
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.problem_type = problem_type
        if head is None:
            head = HeadConfig(num_labels=num_labels, hidden_size=attention_hidden_size, problem_type=problem_type)
        elif not isinstance(head, HeadConfig):
            head = HeadConfig(**head)
        self.head = head
        self.output_contexts = output_contexts

        if pad_token_id is not None and vocab_size <= pad_token_id:
            raise ValueError(f"vocab_size ({vocab_size}) must include pad_token_id ({pad_token_id}).")
        if context < 0:
            raise ValueError(f"context must be non-negative, got {context}.")
        min_dimension = min(
            hidden_size,
            attention_hidden_size,
            num_hidden_layers,
            num_attention_heads,
            intermediate_size,
            bucket_size,
            max_seq_len,
            num_splice_labels,
            num_tissues,
            num_labels,
        )
        if min_dimension <= 0:
            raise ValueError(
                "hidden_size, attention_hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, "
                "bucket_size, max_seq_len, num_splice_labels, num_tissues, and num_labels must be positive."
            )
        if len(self.tissue_names) != num_tissues:
            raise ValueError(f"Expected {num_tissues} tissue names, got {len(self.tissue_names)}.")
        if num_local_attention_heads < 0:
            raise ValueError(f"num_local_attention_heads must be non-negative, got {num_local_attention_heads}.")
        if attention_hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"attention_hidden_size ({attention_hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        if num_local_attention_heads > num_attention_heads:
            raise ValueError(
                f"num_local_attention_heads ({num_local_attention_heads}) cannot exceed "
                f"num_attention_heads ({num_attention_heads})"
            )
        if max_seq_len % bucket_size != 0:
            raise ValueError(f"max_seq_len ({max_seq_len}) must be divisible by bucket_size ({bucket_size})")
        for index, encoder in enumerate(self.encoders):
            if encoder.hidden_size <= 0:
                raise ValueError(f"Encoder {index} has non-positive hidden_size: {encoder.hidden_size}.")
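According to the max_seq_len documentation, the concatenated features are center-cropped or padded to that length before the attention block. A minimal NumPy sketch of that step follows. Zero padding, and placing the extra element on the right when the difference is odd, are assumptions. center_crop_or_pad is a hypothetical helper, not the library function.

```python
import numpy as np


def center_crop_or_pad(features: np.ndarray, target_len: int, pad_value: float = 0.0) -> np.ndarray:
    """Center-crop or symmetrically pad a (length, channels) array along the length axis.

    ASSUMPTIONS: constant (zero) padding; the extra element goes to the right end when odd.
    """
    length = features.shape[0]
    if length >= target_len:
        start = (length - target_len) // 2
        return features[start : start + target_len]
    total = target_len - length
    left = total // 2
    right = total - left
    return np.pad(features, ((left, right), (0, 0)), constant_values=pad_value)


x = np.arange(1, 11, dtype=float).reshape(10, 1)  # toy (length=10, channels=1) features
print(center_crop_or_pad(x, 6)[:, 0])
print(center_crop_or_pad(x, 13)[:, 0])
```

Because the configuration also requires max_seq_len to be divisible by bucket_size, the cropped or padded length always splits evenly into attention buckets (8192 / 64 = 128 with the defaults).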

SpTransformerFeatureEncoderConfig

Bases: FlatDict

Configuration for a single SpliceAI-style convolutional feature encoder used by SpTransformer.

SpTransformer reuses two pre-trained dilated-residual convolutional encoders to extract per-position sequence features. Each encoder is a stack of dilated residual blocks; the feature map is taken before the encoder’s own output projections.

Parameters:

Name Type Description Default

hidden_size

int

Number of channels in the encoder.

128
Source code in multimolecule/models/sptransformer/configuration_sptransformer.py
Python
class SpTransformerFeatureEncoderConfig(FlatDict):
    r"""
    Configuration for a single SpliceAI-style convolutional feature encoder used by SpTransformer.

    SpTransformer reuses two pre-trained dilated-residual convolutional encoders to extract per-position
    sequence features. Each encoder is a stack of dilated residual blocks; the feature map is taken before
    the encoder's own output projections.

    Args:
        hidden_size:
            Number of channels in the encoder.
    """

    hidden_size: int = 128

SpTransformerAttentionMap dataclass

Bases: ModelOutput

Faithful, structured attention weights for one SpTransformer attention layer.

SpTransformer’s attention layer (SpTransformerSelfAttention) is not dense self-attention. It splits the heads into two groups with fundamentally different sparse-attention structures, so there is no single dense (batch, heads, seq, seq) tensor that faithfully represents the computation. Fabricating one (e.g. by scattering the block weights into a zero-filled seq x seq grid) would be a misleading interpretability artifact. Instead, this object exposes the actual softmax weights computed in the forward pass for each attention type, plus the indexing/permutation needed to map them back to absolute sequence positions.

Conventions: B = batch, S = sequence length, W = window_size = bucket_size, num_windows = S // W, num_buckets = S // W. Local heads come first along the head axis, Sinkhorn heads second, matching the split inside SpTransformerSelfAttention.

Parameters:

Name Type Description Default

bucket_size

`int`

W, the local-attention window size and Sinkhorn bucket size.

None

look_backward

`int`

number of preceding windows each local window attends to (1 upstream).

None

look_forward

`int`

number of following windows each local window attends to (1 downstream).

None

num_local_heads

`int`

number of windowed-local heads (first heads along the head axis).

None

num_sinkhorn_heads

`int`

number of Sinkhorn sorted-bucket heads (remaining heads).

None

sequence_length

`int`

S, the attention-block sequence length these weights were computed on.

None

Faithfulness guarantee: re-deriving the per-type attention output by contracting these exact softmax weights with the (block-gathered) values reproduces the layer’s attention output bit-for-bit.

Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
@dataclass
class SpTransformerAttentionMap(ModelOutput):
    r"""
    Faithful, structured attention weights for **one** SpTransformer attention layer.

    SpTransformer's attention layer (`SpTransformerSelfAttention`) is *not* dense self-attention. It
    splits the heads into two groups with fundamentally different sparse-attention structures, so there is **no
    single dense `(batch, heads, seq, seq)` tensor** that faithfully represents the computation. Fabricating one
    (e.g. by scattering the block weights into a zero-filled `seq x seq` grid) would be a misleading
    interpretability artifact. Instead, this object exposes the *actual* `softmax` weights computed in the
    forward pass for each attention type, plus the indexing/permutation needed to map them back to absolute
    sequence positions.

    Conventions: ``B`` = batch, ``S`` = sequence length, ``W`` = ``window_size`` = ``bucket_size``,
    ``num_windows`` = ``S // W``, ``num_buckets`` = ``S // W``. Local heads come first along the head axis,
    Sinkhorn heads second, matching the split inside `SpTransformerSelfAttention`.

    Args:
        local_attentions (`torch.FloatTensor` of shape
            `(B, num_local_heads, num_windows, W, (look_backward + 1 + look_forward) * W)`, *optional*):
            Per-window softmax weights of the *windowed-local* heads. For window ``w``, query position ``i``
            (a token at absolute position ``w * W + i``) attends to a key axis that is the concatenation of
            windows ``[w - look_backward, ..., w, ..., w + look_forward]`` (each of length ``W``). Out-of-range
            neighbour windows are zero-padded and *masked* (their softmax weight is exactly ``0``). Use
            `local_key_positions` to recover the absolute source position of every key-axis column. `None` when
            the layer has no local heads.
        local_key_positions (`torch.LongTensor` of shape
            `(num_windows, (look_backward + 1 + look_forward) * W)`, *optional*):
            Absolute source sequence position for each key-axis column of `local_attentions`. Padded
            (out-of-range) columns are marked with ``-1`` and always carry softmax weight ``0``. Shared across
            batch and heads.
        sinkhorn_attentions (`torch.FloatTensor` of shape
            `(B, num_sinkhorn_heads, num_buckets, W, 2 * W)`, *optional*):
            Per-bucket softmax weights of the *Sinkhorn sorted-bucket* heads. For query bucket ``u``, the key
            axis (length ``2 * W``) is the concatenation of (a) the *single sorted/reordered* key bucket
            selected for ``u`` (columns ``0 : W``) and (b) ``u``'s own local bucket (columns ``W : 2 * W``).
            Map columns back to sequence positions via `sinkhorn_reorder` (for the first half) and
            ``u * W + j`` (for the second half). `None` when the layer has no Sinkhorn heads.
        sinkhorn_reorder (`torch.FloatTensor` of shape `(B, num_sinkhorn_heads, num_buckets, num_buckets)`,
            *optional*):
            The bucket-permutation / sort matrix produced by the parameter-free attention-sort net
            (``differentiable_topk(R, k=1)``). Row ``u`` is a one-hot-like (weighted) row whose nonzero entry
            ``v`` means query bucket ``u``'s reordered key bucket (columns ``0 : W`` of `sinkhorn_attentions`)
            is source bucket ``v`` (absolute positions ``v * W : v * W + W``), scaled by that entry. This is
            exactly the matrix used to gather the reordered keys in the forward pass. `None` when the layer has
            no Sinkhorn heads.
        bucket_size (`int`): ``W``, the local-attention window size and Sinkhorn bucket size.
        look_backward (`int`): number of preceding windows each local window attends to (``1`` upstream).
        look_forward (`int`): number of following windows each local window attends to (``1`` downstream).
        num_local_heads (`int`): number of windowed-local heads (first heads along the head axis).
        num_sinkhorn_heads (`int`): number of Sinkhorn sorted-bucket heads (remaining heads).
        sequence_length (`int`): ``S``, the attention-block sequence length these weights were computed on.

    Faithfulness guarantee: re-deriving the per-type attention output by contracting these exact softmax
    weights with the (block-gathered) values reproduces the layer's attention output bit-for-bit.
    """

    local_attentions: torch.FloatTensor | None = None
    local_key_positions: torch.LongTensor | None = None
    sinkhorn_attentions: torch.FloatTensor | None = None
    sinkhorn_reorder: torch.FloatTensor | None = None
    bucket_size: int | None = None
    look_backward: int | None = None
    look_forward: int | None = None
    num_local_heads: int | None = None
    num_sinkhorn_heads: int | None = None
    sequence_length: int | None = None
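For the Sinkhorn heads, the docstring above maps key columns `0 : W` to the source bucket selected by row `u` of `sinkhorn_reorder`, and columns `W : 2 * W` to positions `u * W + j`. The plain-Python sketch below applies that rule to one row. It is hypothetical, and it assumes the row is one-hot-like so its largest entry identifies the source bucket `v`; the entry's magnitude, which scales the gathered keys, is not reflected in the returned positions.

```python
def sinkhorn_key_positions(reorder_row: list[float], query_bucket: int, bucket_size: int) -> list[int]:
    """Hypothetical mapping of the 2 * W Sinkhorn key columns of bucket u to sequence positions."""
    # Assumes a one-hot-like row: the largest entry marks the gathered source bucket v.
    source_bucket = max(range(len(reorder_row)), key=lambda v: reorder_row[v])
    reordered = [source_bucket * bucket_size + j for j in range(bucket_size)]  # columns 0 : W
    own = [query_bucket * bucket_size + j for j in range(bucket_size)]  # columns W : 2 * W
    return reordered + own


print(sinkhorn_key_positions([0.0, 0.0, 1.0, 0.0], query_bucket=1, bucket_size=2))
# [4, 5, 2, 3]
```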

SpTransformerForTokenPrediction

Bases: SpTransformerPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import (
...     SpTransformerConfig,
...     SpTransformerFeatureEncoderConfig,
...     SpTransformerForTokenPrediction,
... )
>>> encoder = SpTransformerFeatureEncoderConfig(hidden_size=4)
>>> config = SpTransformerConfig(
...     context=2, hidden_size=8, encoders=[encoder], attention_hidden_size=16,
...     num_hidden_layers=1, num_attention_heads=2, num_local_attention_heads=1,
...     intermediate_size=32, bucket_size=4, max_seq_len=8, num_tissues=2, num_labels=2,
... )
>>> model = SpTransformerForTokenPrediction(config)
>>> input_ids = torch.randint(5, (1, 8))
>>> output = model(input_ids, labels=torch.rand(1, 8, 2))
>>> output["logits"].shape
torch.Size([1, 8, 2])
>>> output["loss"]
tensor(..., grad_fn=<MseLossBackward0>)
Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
class SpTransformerForTokenPrediction(SpTransformerPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import (
        ...     SpTransformerConfig,
        ...     SpTransformerFeatureEncoderConfig,
        ...     SpTransformerForTokenPrediction,
        ... )
        >>> encoder = SpTransformerFeatureEncoderConfig(hidden_size=4)
        >>> config = SpTransformerConfig(
        ...     context=2, hidden_size=8, encoders=[encoder], attention_hidden_size=16,
        ...     num_hidden_layers=1, num_attention_heads=2, num_local_attention_heads=1,
        ...     intermediate_size=32, bucket_size=4, max_seq_len=8, num_tissues=2, num_labels=2,
        ... )
        >>> model = SpTransformerForTokenPrediction(config)
        >>> input_ids = torch.randint(5, (1, 8))
        >>> output = model(input_ids, labels=torch.rand(1, 8, 2))
        >>> output["logits"].shape
        torch.Size([1, 8, 2])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<MseLossBackward0>)
    """

    def __init__(self, config: SpTransformerConfig):
        super().__init__(config)
        self.model = SpTransformerModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SpTransformerTokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)

        outputs = self.model(
            input_ids,
            attention_mask=head_attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return SpTransformerTokenPredictorOutput(
            loss=loss,
            logits=logits,
            contexts=outputs.contexts,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
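The doctest for `SpTransformerForTokenPrediction` reports `grad_fn=<MseLossBackward0>`, so with float `labels` the token head's loss is a mean-squared error. The plain-Python sketch below computes MSE with mean reduction (the default of `torch.nn.functional.mse_loss`). Whether `TokenPredictionHead` drops padded positions via the attention mask before averaging is not shown on this page, so the sketch simply averages over every element; treat that as an assumption.

```python
def mean_squared_error(logits: list[list[float]], labels: list[list[float]]) -> float:
    """Element-wise squared error averaged over all positions and labels (mean reduction)."""
    squared = [
        (prediction - target) ** 2
        for logit_row, label_row in zip(logits, labels)
        for prediction, target in zip(logit_row, label_row)
    ]
    return sum(squared) / len(squared)


# Two positions x num_labels=2, mirroring the (sequence_length, num_labels) layout of `logits`.
print(mean_squared_error([[0.5, 1.0], [0.0, 2.0]], [[0.0, 1.0], [0.0, 1.0]]))
# 0.3125
```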

SpTransformerModel

Bases: SpTransformerPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import SpTransformerConfig, SpTransformerFeatureEncoderConfig, SpTransformerModel
>>> encoder = SpTransformerFeatureEncoderConfig(hidden_size=4)
>>> config = SpTransformerConfig(
...     context=2, hidden_size=8, encoders=[encoder], attention_hidden_size=16,
...     num_hidden_layers=1, num_attention_heads=2, num_local_attention_heads=1,
...     intermediate_size=32, bucket_size=4, max_seq_len=8, num_tissues=2,
... )
>>> model = SpTransformerModel(config)
>>> input_ids = torch.randint(5, (1, 8))
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 8, 16])
>>> output["logits"].shape
torch.Size([1, 8, 5])
Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
class SpTransformerModel(SpTransformerPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import SpTransformerConfig, SpTransformerFeatureEncoderConfig, SpTransformerModel
        >>> encoder = SpTransformerFeatureEncoderConfig(hidden_size=4)
        >>> config = SpTransformerConfig(
        ...     context=2, hidden_size=8, encoders=[encoder], attention_hidden_size=16,
        ...     num_hidden_layers=1, num_attention_heads=2, num_local_attention_heads=1,
        ...     intermediate_size=32, bucket_size=4, max_seq_len=8, num_tissues=2,
        ... )
        >>> model = SpTransformerModel(config)
        >>> input_ids = torch.randint(5, (1, 8))
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 8, 16])
        >>> output["logits"].shape
        torch.Size([1, 8, 5])
    """

    def __init__(self, config: SpTransformerConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = SpTransformerEmbedding(config)
        self.feature_encoders = nn.ModuleList(
            [SpTransformerFeatureEncoder(c["hidden_size"], config) for c in config.encoders]
        )
        self.projection = SpTransformerProjection(config)
        self.encoder = SpTransformerEncoder(config)
        self.prediction = SpTransformerPredictionHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_splice_labels == 3:
            splice_channels = ["no_splice", "acceptor", "donor"]
        else:
            splice_channels = [f"splice_label_{index}" for index in range(self.config.num_splice_labels)]
        return splice_channels + list(self.config.tissue_names)

    def postprocess(self, outputs: SpTransformerModelOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
        r"""
        Return SpTransformer splice-site probabilities and tissue-usage scores with semantic channel names.

        Args:
            outputs: The output of [`SpTransformerModel`][multimolecule.models.SpTransformerModel], or its `logits`
                tensor.

        Returns:
            A tuple of `(scores, channels)`. The splice-site channels are softmax-normalized; tissue-usage channels
            are returned in the model's native scale.
        """
        logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
        splice = logits[..., : self.config.num_splice_labels].softmax(dim=-1)
        usage = logits[..., self.config.num_splice_labels : self.config.num_splice_labels + self.config.num_tissues]
        return torch.cat([splice, usage], dim=-1), self.output_channels

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SpTransformerModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        output_contexts = kwargs.get("output_contexts", self.config.output_contexts)
        output_hidden_states = kwargs.get("output_hidden_states", self.config.output_hidden_states)
        output_attentions = kwargs.get("output_attentions", self.config.output_attentions)

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        target_output_len = embedding_output.size(2) - 2 * self.config.context
        max_seq_len = self.config.max_seq_len
        odd_fix = embedding_output.size(2) & 1

        projected = self.projection(embedding_output)
        projected = _pad_or_crop_features(projected, max_seq_len, odd_fix)

        features = [feature_encoder(embedding_output) for feature_encoder in self.feature_encoders]
        if features:
            features = torch.cat(features, dim=1)
            features = _pad_or_crop_features(features, max_seq_len, odd_fix)
            hidden_state = torch.cat([features, projected], dim=1)
        else:
            hidden_state = projected
        hidden_state = self.projection.fuse(hidden_state)
        encoder_outputs = self.encoder(
            hidden_state,
            output_attentions=bool(output_attentions),
            output_hidden_states=bool(output_hidden_states),
        )
        context = encoder_outputs.last_hidden_state

        logits = self.prediction(context)

        logits = _pad_or_crop_outputs(logits, target_output_len, odd_fix)
        context = context.transpose(1, 2)
        context = _pad_or_crop_outputs(context, target_output_len, odd_fix)
        hidden_states = None
        if output_hidden_states and encoder_outputs.hidden_states is not None:
            hidden_states = tuple(
                _pad_or_crop_outputs(hidden_state, target_output_len, odd_fix)
                for hidden_state in encoder_outputs.hidden_states
            )

        return SpTransformerModelOutput(
            last_hidden_state=context,
            logits=logits,
            contexts=(context,) if output_contexts else None,
            hidden_states=hidden_states,
        )
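`forward` derives the prediction length and a parity flag from the embedded length before padding or cropping features to `max_seq_len`. The plain-Python sketch below reproduces only that arithmetic (`target_output_len`, `odd_fix`) and takes the embedded length as input instead of running `SpTransformerEmbedding`; `output_geometry` is a hypothetical name. The doctest maps 8 input tokens to 8 output positions with `context=2`, which suggests (an inference, not a documented guarantee) that the embedding adds `context` positions on each side.

```python
def output_geometry(embedded_length: int, context: int) -> tuple[int, int]:
    """Mirror of the length arithmetic in SpTransformerModel.forward."""
    target_output_len = embedded_length - 2 * context  # positions that receive predictions
    odd_fix = embedded_length & 1  # 1 when the embedded length is odd, else 0
    return target_output_len, odd_fix


# Doctest config: 8 tokens plus (by inference) 2 flanking positions per side -> 12 embedded.
print(output_geometry(embedded_length=12, context=2))  # (8, 0)
print(output_geometry(embedded_length=13, context=2))  # (9, 1)
```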

postprocess

Python
postprocess(
    outputs: (
        SpTransformerModelOutput | ModelOutput | Tensor
    ),
) -> tuple[Tensor, list[str]]

Return SpTransformer splice-site probabilities and tissue-usage scores with semantic channel names.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `SpTransformerModelOutput`, `ModelOutput`, or `Tensor` | The output of `SpTransformerModel`, or its `logits` tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, list[str]]` | A tuple of `(scores, channels)`. The splice-site channels are softmax-normalized; the tissue-usage channels are returned in the model's native scale. |

Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
def postprocess(self, outputs: SpTransformerModelOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
    r"""
    Return SpTransformer splice-site probabilities and tissue-usage scores with semantic channel names.

    Args:
        outputs: The output of [`SpTransformerModel`][multimolecule.models.SpTransformerModel], or its `logits`
            tensor.

    Returns:
        A tuple of `(scores, channels)`. The splice-site channels are softmax-normalized; tissue-usage channels
        are returned in the model's native scale.
    """
    logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
    splice = logits[..., : self.config.num_splice_labels].softmax(dim=-1)
    usage = logits[..., self.config.num_splice_labels : self.config.num_splice_labels + self.config.num_tissues]
    return torch.cat([splice, usage], dim=-1), self.output_channels
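`postprocess` softmax-normalizes the first `num_splice_labels` channels, passes the next `num_tissues` channels through unchanged, and pairs the result with `output_channels`. The plain-Python sketch below mirrors both steps for a single position's logit vector. The function names and the tissue names are hypothetical; the pretrained model's real tissue names come from `config.tissue_names`.

```python
import math


def postprocess_position(logits: list[float], num_splice_labels: int, num_tissues: int) -> list[float]:
    """Pure-Python mirror of postprocess for one position's logit vector."""
    splice_logits = logits[:num_splice_labels]
    peak = max(splice_logits)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in splice_logits]
    total = sum(exps)
    splice = [value / total for value in exps]  # softmax over the splice-site channels
    usage = logits[num_splice_labels : num_splice_labels + num_tissues]  # native scale, untouched
    return splice + usage


def channel_names(num_splice_labels: int, tissue_names: list[str]) -> list[str]:
    """Mirror of the output_channels property."""
    if num_splice_labels == 3:
        splice = ["no_splice", "acceptor", "donor"]
    else:
        splice = [f"splice_label_{index}" for index in range(num_splice_labels)]
    return splice + list(tissue_names)


scores = postprocess_position([0.0, 0.0, 0.0, 0.7, -0.2], num_splice_labels=3, num_tissues=2)
names = channel_names(3, ["tissue_a", "tissue_b"])  # hypothetical tissue names
for name, score in zip(names, scores):
    print(f"{name}: {score:.4f}")
```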

SpTransformerModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the SpTransformer model.

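Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `last_hidden_state` | `torch.FloatTensor` of shape `(batch_size, sequence_length, attention_hidden_size)` | Per-position attention-block representation. Consumed by `TokenPredictionHead`. | None |
| `logits` | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_splice_labels + num_tissues)` | Original SpTransformer outputs: per-position splice-site scores (no-splice / acceptor / donor) followed by per-tissue splice-site usage scores. | None |
| `contexts` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True` is passed or when `config.output_contexts=True` | Tuple with the per-position attention-block representation of shape `(batch_size, sequence_length, attention_hidden_size)`. | None |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Attention-block hidden states before the first layer and after each layer, cropped or padded to the predicted sequence length. The final entry is the same representation as `last_hidden_state`. | None |
| `attentions` | `tuple(SpTransformerAttentionMap)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True` | One `SpTransformerAttentionMap` per attention layer, in forward order: structured windowed-local and Sinkhorn sorted-bucket softmax weights plus the metadata needed to map them to sequence positions. These are **not** dense `(batch, heads, seq, seq)` attention matrices. | None |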
Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
@dataclass
class SpTransformerModelOutput(ModelOutput):
    """
    Base class for outputs of the SpTransformer model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, attention_hidden_size)`):
            Per-position attention-block representation. Consumed by [`TokenPredictionHead`].
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_splice_labels + num_tissues)`):
            Original SpTransformer per-position splice-site score (no-splice / acceptor / donor) and per-tissue
            splice-site usage score outputs.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True` is passed or when
            `config.output_contexts=True`):
            Tuple with the per-position attention-block representation of shape `(batch_size, sequence_length,
            attention_hidden_size)`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Attention-block hidden states before the first layer and after each layer, cropped or padded to the
            predicted sequence length. The final entry is the same representation as `last_hidden_state`.
        attentions (`tuple(SpTransformerAttentionMap)`, *optional*, returned when `output_attentions=True` is passed or
            when `config.output_attentions=True`):
            One [`SpTransformerAttentionMap`] per attention layer (in forward order). SpTransformer mixes
            *windowed-local* heads and *Sinkhorn sorted-bucket* heads, which are heterogeneous sparse attention
            patterns that **cannot** be faithfully flattened into a single dense `(batch, heads, seq, seq)` tensor.
            Each [`SpTransformerAttentionMap`] therefore exposes the *real, structured* softmax weights actually
            used in the forward pass, together with the index/permutation metadata needed to map them back to
            sequence positions. See [`SpTransformerAttentionMap`] for the exact schema. **These are not dense
            attention matrices**; treating them as such would misrepresent the computation.
    """

    last_hidden_state: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contexts: tuple[torch.FloatTensor, ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[SpTransformerAttentionMap, ...] | None = None

SpTransformerPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
class SpTransformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SpTransformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["SpTransformerLayer", "SpTransformerResidualBlock"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
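`_init_weights` draws `Conv1d` and `Linear` biases from `U(-1/sqrt(fan_in), 1/sqrt(fan_in))`, where `fan_in` comes from PyTorch's `nn.init._calculate_fan_in_and_fan_out`: the weight's second dimension times the product of any remaining (kernel) dimensions. The plain-Python sketch below reproduces that bound from a weight shape. `bias_bound` is a hypothetical name, and the example shapes are illustrative, not taken from the pretrained checkpoint.

```python
import math


def bias_bound(weight_shape: tuple[int, ...]) -> float:
    """Half-width of the uniform bias range used in _init_weights for Conv1d / Linear layers."""
    receptive_field = math.prod(weight_shape[2:])  # kernel size for Conv1d; 1 for Linear (out, in)
    fan_in = weight_shape[1] * receptive_field
    return 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0


print(bias_bound((32, 4, 11)))  # Conv1d (out, in / groups, kernel): fan_in = 4 * 11 = 44
print(bias_bound((16, 64)))  # Linear (out, in): fan_in = 64 -> 0.125
```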

SpTransformerTokenPredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of SpTransformer token prediction models.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `torch.FloatTensor`, *optional*, returned when `labels` is provided | Token prediction loss. | None |
| `logits` | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)` | Per-nucleotide prediction outputs. | None |
| `contexts` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True` | Per-position attention-block representations. | None |
| `hidden_states` | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` | Attention-block hidden states before the first layer and after each layer. | None |
| `attentions` | `tuple(SpTransformerAttentionMap)`, *optional*, returned when `output_attentions=True` | Structured sparse-attention weights for each attention layer. | None |
Source code in multimolecule/models/sptransformer/modeling_sptransformer.py
Python
@dataclass
class SpTransformerTokenPredictorOutput(ModelOutput):
    """
    Base class for outputs of SpTransformer token prediction models.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Token prediction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-nucleotide prediction outputs.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`):
            Per-position attention-block representations.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Attention-block hidden states before the first layer and after each layer.
        attentions (`tuple(SpTransformerAttentionMap)`, *optional*, returned when `output_attentions=True`):
            Structured sparse-attention weights for each attention layer.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contexts: tuple[torch.FloatTensor, ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[SpTransformerAttentionMap, ...] | None = None