
DeltaSplice

Reference-informed prediction of alternative splicing and splicing-altering mutations from sequences.

Disclaimer

This is an UNOFFICIAL implementation of Reference-informed prediction of alternative splicing and splicing-altering mutations from sequences by Chencheng Xu, Suying Bao, et al.

The OFFICIAL repository of DeltaSplice is at chaolinzhanglab/DeltaSplice.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing DeltaSplice did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

DeltaSplice predicts splice-site usage (SSU) and splicing-altering mutation effects from sequence. The model uses a valid-convolution dilated residual encoder and three prediction modules: splice-site usage, reference-informed delta-SSU, and an auxiliary splice-site head. The official package uses the average prediction of five checkpoints for SSU and delta-SSU prediction; MultiMolecule stores the five seed checkpoints of each released data variant as internal ensemble members and returns their average prediction.
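The ensemble averaging described above can be shown with a minimal plain-Python sketch. It assumes that each member returns one [no_splice, acceptor, donor] row per position and takes the element-wise mean. The model itself averages tensors through an internal helper, so `average_members` is a hypothetical name used only for illustration. The mean of several probability rows is still a probability row, so each averaged row continues to sum to 1.

```python
def average_members(member_probabilities):
    """Element-wise mean of per-member predictions (hypothetical helper).

    member_probabilities: one entry per ensemble member, each a list of
    per-position rows [no_splice, acceptor, donor] of equal length.
    """
    num_members = len(member_probabilities)
    if num_members == 0:
        raise ValueError("need at least one ensemble member")
    # zip(*members) walks positions; zip(*rows) walks channels across members.
    return [
        [sum(values) / num_members for values in zip(*rows)]
        for rows in zip(*member_probabilities)
    ]


# Two hypothetical members scoring a single position.
averaged = average_members([[[0.8, 0.1, 0.1]], [[0.6, 0.3, 0.1]]])
# averaged is close to [[0.7, 0.2, 0.1]]
```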

Variants

Model Specification

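| Variant | Num Layers | Hidden Size | Context | Ensemble Members | Num Parameters (M) | FLOPs (M) | MACs (M) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| DeltaSplice | 24 | 64 | 30000 | 5 | 40.376 | 1642965.72 | 820284.36 |
| DeltaSplice-Human | 24 | 64 | 30000 | 5 | 40.376 | 1642965.72 | 820284.36 |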

(FLOPs and MACs measured on one requested output nucleotide with the default 30 kb padded context.)

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Splice-Site Usage

Python
>>> from multimolecule import RnaTokenizer
>>> from multimolecule.models.deltasplice import DeltaSpliceModel

>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/deltasplice")
>>> model = DeltaSpliceModel.from_pretrained("multimolecule/deltasplice")
>>> inputs = tokenizer("AGCAGUCAUUAUGGCGAAUCUGGCAAGUA", return_tensors="pt")
>>> output = model(**inputs)
>>> output["probabilities"].shape
torch.Size([1, 30, 3])

Variant Effect

Python
>>> from multimolecule import RnaTokenizer
>>> from multimolecule.models.deltasplice import DeltaSpliceModel

>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/deltasplice")
>>> model = DeltaSpliceModel.from_pretrained("multimolecule/deltasplice")
>>> reference = tokenizer("AGCAGUCAUUAUGGCGAAUCUGGCAAGUA", return_tensors="pt")
>>> alternative = tokenizer("AGCAGUCAUUAUGGCUAAUCUGGCAAGUA", return_tensors="pt")
>>> output = model(reference["input_ids"], alternative_input_ids=alternative["input_ids"], use_reference=True)
>>> output["delta"].shape
torch.Size([1, 30, 3])
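In the example above, the alternative sequence differs from the reference by one G-to-U substitution at 0-based position 15. The hypothetical helper `apply_snv` below builds this kind of pair from a reference and a single substitution. It checks the reference base before substituting. Because a substitution keeps both sequences the same length, their per-position outputs stay aligned. The helper is a sketch and is not part of the MultiMolecule API.

```python
def apply_snv(reference: str, position: int, ref_base: str, alt_base: str) -> str:
    """Return `reference` with one substitution at 0-based `position` (hypothetical helper)."""
    found = reference[position]
    if found != ref_base:
        # Guard against coordinate or strand mistakes before scoring.
        raise ValueError(f"expected {ref_base!r} at position {position}, found {found!r}")
    return reference[:position] + alt_base + reference[position + 1 :]


reference = "AGCAGUCAUUAUGGCGAAUCUGGCAAGUA"
alternative = apply_snv(reference, 15, "G", "U")
# alternative == "AGCAGUCAUUAUGGCUAAUCUGGCAAGUA", the sequence used in the example above
```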

Interface

  • Input: RNA sequence tokenized with RnaTokenizer; N is encoded as all-zero nucleotide channels
  • Output channels: no_splice, acceptor, donor
  • Reference-only call: returns per-position splice-site usage probabilities in probabilities
  • Reference + alternative call: pass the reference sequence as input_ids and the alternate sequence as alternative_input_ids
  • Reference usage: pass reference_usage with shape (batch_size, sequence_length, 3), or omit it and set use_reference=True to use the model's own reference usage as the reference signal; passing reference_usage enables use_reference automatically
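The output channels follow a fixed order (no_splice, acceptor, donor), so each per-position row can be mapped back to channel names. The sketch below lists the positions where acceptor or donor usage reaches a threshold. `call_sites` is a hypothetical helper, and the 0.5 threshold is arbitrary; DeltaSplice does not define it. The sketch works on plain nested lists, such as `output["probabilities"][0].tolist()`.

```python
CHANNELS = ("no_splice", "acceptor", "donor")  # channel order listed above


def call_sites(probabilities, threshold=0.5):
    """Return (position, channel) pairs whose acceptor/donor usage >= threshold (hypothetical helper)."""
    calls = []
    for position, row in enumerate(probabilities):
        for channel, value in zip(CHANNELS, row):
            # Skip the background channel; only acceptor and donor are site calls.
            if channel != "no_splice" and value >= threshold:
                calls.append((position, channel))
    return calls


example = [[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]]  # made-up rows
sites = call_sites(example)  # [(1, "acceptor"), (2, "donor")]
```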

Training Details

DeltaSplice was trained to predict splice-site usage from gene sequence and to improve mutation-effect prediction by incorporating reference splice-site usage.

Training Data

The upstream repository describes training from gene_dataset.tsu.txt, which contains splice-site usage in adult brains of eight mammalian species.

Training Procedure

The official release provides five seed checkpoints with the same architecture and data split. MultiMolecule represents these seed checkpoints as internal ensemble members rather than public model variants.

Citation

BibTeX
@article{xu2024deltasplice,
  title     = {Reference-informed prediction of alternative splicing and splicing-altering mutations from sequences},
  author    = {Xu, Chencheng and Bao, Suying and Wang, Ye and Li, Wenxing and Chen, Hao and Shen, Yufeng and Jiang, Tao and Zhang, Chaolin},
  journal   = {Genome Research},
  volume    = {34},
  number    = {7},
  pages     = {1052--1065},
  year      = {2024},
  doi       = {10.1101/gr.279044.124}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If MultiMolecule supports your research, please cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the DeltaSplice paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

DeltaSpliceConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a DeltaSpliceModel. It is used to instantiate a DeltaSplice model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to the official chaolinzhanglab/DeltaSplice architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default 4): Vocabulary size of the DeltaSplice one-hot input channels. Defaults to 4 (A, C, G, U); the N padding token is encoded as all-zero channels.
  • context (int, default 30000): Number of flanking nucleotides represented around the requested output positions. The model pads context // 2 zero-context positions on each side, reproducing the upstream fixed-window interface while returning one output per input token.
  • hidden_size (int, default 64): Dimensionality of the convolutional encoder.
  • layers (list[DeltaSpliceLayerConfig] | None, default None): Configuration for each dilated residual layer. Each layer is a DeltaSpliceLayerConfig object; None selects the default 24-layer stack.
  • hidden_act (str, default 'relu'): The non-linear activation function (function or string) in the encoder and prediction heads.
  • dropout (float, default 0.3): Dropout probability used between the two convolutions of each residual layer.
  • batch_norm_eps (float, default 1e-05): The epsilon used by batch normalization layers.
  • batch_norm_momentum (float, default 0.1): The momentum used by batch normalization layers.
  • num_ensemble (int, default 5): Number of internal checkpoint members averaged by the model. The official DeltaSplice releases provide five seed checkpoints per variant.
  • num_labels (int, default 3): Number of splice-site usage labels (no_splice, acceptor, donor). Must be 3; any other value raises a ValueError.
  • head (HeadConfig | None, default None): Configuration of the optional token prediction head.
  • problem_type (str | None, default 'regression'): Problem type for the optional token prediction head.
  • output_contexts (bool, default False): Whether to output intermediate encoder representations.
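The vocab_size and N handling described above can be pictured as a one-hot encoding where any id outside the four nucleotide channels produces an all-zero row. This plain-Python sketch makes two assumptions: ids 0 to 3 map to A, C, G, U in that order, and N uses id 4, which is the default pad_token_id and unk_token_id. The real encoding happens inside the model's embedding module, which this page does not show, and `one_hot` is a hypothetical name.

```python
def one_hot(token_ids, vocab_size=4):
    """One-hot encode token ids (hypothetical helper).

    Channel c is 1.0 only when the id equals c, so any id >= vocab_size
    (such as the assumed N/pad id 4) never matches and yields an all-zero row.
    """
    return [
        [1.0 if token_id == channel else 0.0 for channel in range(vocab_size)]
        for token_id in token_ids
    ]


encoded = one_hot([0, 2, 4])  # assumed ids for A, G, N
```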

Examples:

Python Console Session
>>> from multimolecule.models.deltasplice import DeltaSpliceConfig, DeltaSpliceLayerConfig, DeltaSpliceModel
>>> layer = DeltaSpliceLayerConfig(kernel_size=3, dilation=1)
>>> configuration = DeltaSpliceConfig(context=4, hidden_size=8, layers=[layer], num_ensemble=1)
>>> model = DeltaSpliceModel(configuration)
>>> configuration = model.config
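When you build a custom configuration like the one above, the checks in DeltaSpliceConfig.__init__ decide which context values are accepted. The hypothetical helper below restates those checks for a list of (kernel_size, dilation) pairs and returns the smallest admissible context. Every kernel must be odd, so the reduction is always even. As a result, the same-parity rule adds nothing beyond the requirement that context be even. For the single (3, 1) layer used above, the result is 4, which matches context=4.

```python
def smallest_valid_context(layers):
    """Smallest `context` that DeltaSpliceConfig would accept for these layers (hypothetical helper).

    layers: list of (kernel_size, dilation) pairs.
    """
    for index, (kernel_size, dilation) in enumerate(layers):
        if min(kernel_size, dilation) <= 0:
            raise ValueError(f"layer {index}: non-positive kernel size or dilation")
        if kernel_size % 2 == 0:
            raise ValueError(f"layer {index}: even kernel size {kernel_size}; expected odd")
    # Same formula as the convolution_reduction property.
    reduction = 2 * sum((kernel_size - 1) * dilation for kernel_size, dilation in layers)
    # context must be positive, even, and >= reduction; reduction is even,
    # so the answer is the reduction itself, or 2 when the reduction is 0.
    return max(reduction, 2)


minimum = smallest_valid_context([(3, 1)])  # 4
```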
Source code in multimolecule/models/deltasplice/configuration_deltasplice.py
Python
class DeltaSpliceConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`DeltaSpliceModel`][multimolecule.models.DeltaSpliceModel]. It is used to instantiate a DeltaSplice model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to the official
    [chaolinzhanglab/DeltaSplice](https://github.com/chaolinzhanglab/DeltaSplice) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the DeltaSplice one-hot input channels. Defaults to 4 (`A`, `C`, `G`, `U`); the `N`
            padding token is encoded as all-zero channels.
        context:
            Number of flanking nucleotides represented around the requested output positions. The model pads
            `context // 2` zero-context positions on each side, reproducing the upstream fixed-window interface while
            returning one output per input token.
        hidden_size:
            Dimensionality of the convolutional encoder.
        layers:
            Configuration for each dilated residual layer. Each layer is a [`DeltaSpliceLayerConfig`] object.
        hidden_act:
            The non-linear activation function (function or string) in the encoder and prediction heads.
        dropout:
            Dropout probability used between the two convolutions of each residual layer.
        batch_norm_eps:
            The epsilon used by batch normalization layers.
        batch_norm_momentum:
            The momentum used by batch normalization layers.
        num_ensemble:
            Number of internal checkpoint members averaged by the model. The official DeltaSplice releases provide
            five seed checkpoints per variant.
        num_labels:
            Number of splice-site usage labels (`no_splice`, `acceptor`, `donor`). Must be 3 for the official
            checkpoints.
        head:
            Configuration of the optional token prediction head.
        problem_type:
            Problem type for the optional token prediction head.
        output_contexts:
            Whether to output intermediate encoder representations.

    Examples:
        >>> from multimolecule.models.deltasplice import DeltaSpliceConfig, DeltaSpliceLayerConfig, DeltaSpliceModel
        >>> layer = DeltaSpliceLayerConfig(kernel_size=3, dilation=1)
        >>> configuration = DeltaSpliceConfig(context=4, hidden_size=8, layers=[layer], num_ensemble=1)
        >>> model = DeltaSpliceModel(configuration)
        >>> configuration = model.config
    """

    model_type = "deltasplice"

    pad_token_id: int = 4
    bos_token_id: int | None = None  # type: ignore[assignment]
    eos_token_id: int | None = None  # type: ignore[assignment]
    unk_token_id: int = 4
    mask_token_id: int | None = None  # type: ignore[assignment]
    null_token_id: int | None = None  # type: ignore[assignment]

    def __init__(
        self,
        vocab_size: int = 4,
        context: int = 30000,
        hidden_size: int = 64,
        layers: list[DeltaSpliceLayerConfig] | None = None,
        hidden_act: str = "relu",
        dropout: float = 0.3,
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        num_ensemble: int = 5,
        num_labels: int = 3,
        head: HeadConfig | None = None,
        problem_type: str | None = "regression",
        output_contexts: bool = False,
        pad_token_id: int = 4,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        unk_token_id: int = 4,
        mask_token_id: int | None = None,
        null_token_id: int | None = None,
        **kwargs,
    ):
        if layers is None:
            kernels = [
                11,
                11,
                11,
                11,
                19,
                19,
                19,
                19,
                25,
                25,
                25,
                25,
                33,
                33,
                33,
                33,
                43,
                43,
                85,
                85,
                85,
                85,
                85,
                85,
            ]
            dilations = [
                1,
                1,
                1,
                1,
                1,
                1,
                1,
                1,
                2,
                2,
                2,
                2,
                8,
                8,
                8,
                8,
                16,
                16,
                16,
                16,
                16,
                16,
                32,
                32,
            ]
            layers = [
                DeltaSpliceLayerConfig(kernel_size=kernel_size, dilation=dilation)
                for kernel_size, dilation in zip(kernels, dilations)
            ]
        self.layers = [
            layer if isinstance(layer, DeltaSpliceLayerConfig) else DeltaSpliceLayerConfig(**layer) for layer in layers
        ]
        if num_labels != 3:
            raise ValueError(f"DeltaSplice emits three usage channels; `num_labels` must be 3, got {num_labels}.")
        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        self.mask_token_id = mask_token_id  # type: ignore[assignment]
        self.null_token_id = null_token_id  # type: ignore[assignment]
        self.vocab_size = vocab_size
        self.context = context
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.dropout = dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.num_ensemble = num_ensemble
        self.problem_type = problem_type
        if head is None:
            head = HeadConfig(num_labels=num_labels, hidden_size=hidden_size, problem_type=problem_type)
        elif not isinstance(head, HeadConfig):
            head = HeadConfig(**head)
        self.head = head
        self.output_contexts = output_contexts

        if vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {vocab_size}.")
        if pad_token_id is not None and pad_token_id < vocab_size:
            raise ValueError(
                f"DeltaSplice expects pad_token_id ({pad_token_id}) outside the {vocab_size} nucleotide channels."
            )
        if hidden_size <= 0:
            raise ValueError(f"hidden_size must be positive, got {hidden_size}.")
        if context <= 0 or context % 2:
            raise ValueError(f"context must be a positive even integer, got {context}.")
        if num_ensemble <= 0:
            raise ValueError(f"num_ensemble must be positive, got {num_ensemble}.")
        if not 0 <= dropout < 1:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}.")
        for index, layer in enumerate(self.layers):
            if min(layer.kernel_size, layer.dilation) <= 0:
                raise ValueError(f"Layer {index} has non-positive kernel size or dilation: {layer}.")
            if layer.kernel_size % 2 == 0:
                raise ValueError(f"Layer {index} uses an even kernel size ({layer.kernel_size}); expected odd.")
        if self.context < self.convolution_reduction:
            raise ValueError(
                f"context ({self.context}) must be at least the encoder reduction ({self.convolution_reduction})."
            )
        if (self.context - self.convolution_reduction) % 2:
            raise ValueError(
                f"context ({self.context}) and encoder reduction ({self.convolution_reduction}) must have same parity."
            )

    @property
    def convolution_reduction(self) -> int:
        r"""Number of positions removed by the unpadded dilated convolutions in the encoder."""
        return 2 * sum((layer.kernel_size - 1) * layer.dilation for layer in self.layers)

    @property
    def encoder_crop(self) -> int:
        r"""Additional positions cropped on each side after the convolutional encoder."""
        return (self.context - self.convolution_reduction) // 2

convolution_reduction property

Python
convolution_reduction: int

Number of positions removed by the unpadded dilated convolutions in the encoder.

encoder_crop property

Python
encoder_crop: int

Additional positions cropped on each side after the convolutional encoder.
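You can check both properties by hand for the default 24-layer configuration. The sketch below rewrites the default kernel and dilation lists from DeltaSpliceConfig.__init__ in compact form and applies the same two formulas. It then confirms the length arithmetic: pad the input by context // 2 on each side, remove the reduction, and remove the crop on both sides. This leaves exactly one output per input position.

```python
# Default layer stack from DeltaSpliceConfig.__init__, written compactly.
DEFAULT_KERNELS = [11] * 4 + [19] * 4 + [25] * 4 + [33] * 4 + [43] * 2 + [85] * 6
DEFAULT_DILATIONS = [1] * 8 + [2] * 4 + [8] * 4 + [16] * 6 + [32] * 2
CONTEXT = 30000


def convolution_reduction(kernels, dilations):
    # Each residual layer has two unpadded convolutions; each removes (k - 1) * d positions.
    return 2 * sum((k - 1) * d for k, d in zip(kernels, dilations))


def encoder_crop(context, reduction):
    # Positions still left over on each side after the convolutions.
    return (context - reduction) // 2


reduction = convolution_reduction(DEFAULT_KERNELS, DEFAULT_DILATIONS)  # 26848
crop = encoder_crop(CONTEXT, reduction)  # 1576

# Pad context // 2 on each side, then remove the reduction and 2 * crop.
sequence_length = 29
padded_length = sequence_length + 2 * (CONTEXT // 2)
assert padded_length - reduction - 2 * crop == sequence_length
```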

DeltaSpliceLayerConfig

Bases: FlatDict

Configuration for a single DeltaSplice dilated residual layer.

Parameters:

  • kernel_size (int, default 11): Convolution kernel size used by both convolutions in the residual layer.
  • dilation (int, default 1): Dilation (atrous rate) used by both convolutions in the residual layer.
Source code in multimolecule/models/deltasplice/configuration_deltasplice.py
Python
class DeltaSpliceLayerConfig(FlatDict):
    r"""
    Configuration for a single DeltaSplice dilated residual layer.

    Args:
        kernel_size:
            Convolution kernel size used by both convolutions in the residual layer.
        dilation:
            Dilation (atrous rate) used by both convolutions in the residual layer.
    """

    kernel_size: int = 11
    dilation: int = 1

DeltaSpliceForTokenPrediction

Bases: DeltaSplicePreTrainedModel

DeltaSplice model with a shared MultiMolecule token prediction head.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.deltasplice import (
...     DeltaSpliceConfig,
...     DeltaSpliceForTokenPrediction,
...     DeltaSpliceLayerConfig,
... )
>>> layer = DeltaSpliceLayerConfig(kernel_size=3, dilation=1)
>>> config = DeltaSpliceConfig(context=4, hidden_size=8, layers=[layer], num_ensemble=1)
>>> model = DeltaSpliceForTokenPrediction(config)
>>> output = model(torch.randint(5, (1, 6)), labels=torch.rand(1, 6, 3))
>>> output["logits"].shape
torch.Size([1, 6, 3])
Source code in multimolecule/models/deltasplice/modeling_deltasplice.py
Python
class DeltaSpliceForTokenPrediction(DeltaSplicePreTrainedModel):
    """
    DeltaSplice model with a shared MultiMolecule token prediction head.

    Examples:
        >>> import torch
        >>> from multimolecule.models.deltasplice import (
        ...     DeltaSpliceConfig,
        ...     DeltaSpliceForTokenPrediction,
        ...     DeltaSpliceLayerConfig,
        ... )
        >>> layer = DeltaSpliceLayerConfig(kernel_size=3, dilation=1)
        >>> config = DeltaSpliceConfig(context=4, hidden_size=8, layers=[layer], num_ensemble=1)
        >>> model = DeltaSpliceForTokenPrediction(config)
        >>> output = model(torch.randint(5, (1, 6)), labels=torch.rand(1, 6, 3))
        >>> output["logits"].shape
        torch.Size([1, 6, 3])
    """

    def __init__(self, config: DeltaSpliceConfig):
        super().__init__(config)
        self.model = DeltaSpliceModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        reference_input_ids: Tensor | NestedTensor | None = None,
        reference_attention_mask: Tensor | None = None,
        reference_inputs_embeds: Tensor | NestedTensor | None = None,
        reference_usage: Tensor | None = None,
        labels: Tensor | None = None,
        output_contexts: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> DeltaSpliceTokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            reference_input_ids=reference_input_ids,
            reference_attention_mask=reference_attention_mask,
            reference_inputs_embeds=reference_inputs_embeds,
            reference_usage=reference_usage,
            output_contexts=output_contexts,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return DeltaSpliceTokenPredictorOutput(
            loss=loss,
            logits=logits,
            contexts=outputs.contexts,
            hidden_states=outputs.hidden_states,
        )

DeltaSpliceModel

Bases: DeltaSplicePreTrainedModel

DeltaSplice backbone and official five-seed ensemble for per-position splice-site usage prediction.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.deltasplice import DeltaSpliceConfig, DeltaSpliceLayerConfig, DeltaSpliceModel
>>> layer = DeltaSpliceLayerConfig(kernel_size=3, dilation=1)
>>> config = DeltaSpliceConfig(context=4, hidden_size=8, layers=[layer], num_ensemble=1)
>>> model = DeltaSpliceModel(config)
>>> output = model(torch.randint(5, (1, 6)))
>>> output["probabilities"].shape
torch.Size([1, 6, 3])
Source code in multimolecule/models/deltasplice/modeling_deltasplice.py
Python
class DeltaSpliceModel(DeltaSplicePreTrainedModel):
    """
    DeltaSplice backbone and official five-seed ensemble for per-position splice-site usage prediction.

    Examples:
        >>> import torch
        >>> from multimolecule.models.deltasplice import DeltaSpliceConfig, DeltaSpliceLayerConfig, DeltaSpliceModel
        >>> layer = DeltaSpliceLayerConfig(kernel_size=3, dilation=1)
        >>> config = DeltaSpliceConfig(context=4, hidden_size=8, layers=[layer], num_ensemble=1)
        >>> model = DeltaSpliceModel(config)
        >>> output = model(torch.randint(5, (1, 6)))
        >>> output["probabilities"].shape
        torch.Size([1, 6, 3])
    """

    def __init__(self, config: DeltaSpliceConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = DeltaSpliceEmbedding(config)
        self.members = nn.ModuleList([DeltaSpliceModule(config) for _ in range(config.num_ensemble)])

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_labels == 3:
            return ["no_splice", "acceptor", "donor"]
        return [f"label_{index}" for index in range(self.config.num_labels)]

    def postprocess(self, outputs: DeltaSpliceModelOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
        r"""
        Return DeltaSplice splice-site usage probabilities with semantic channel names.

        Args:
            outputs: The output of [`DeltaSpliceModel`][multimolecule.models.DeltaSpliceModel], or its
                `probabilities` tensor.

        Returns:
            A tuple of `(scores, channels)`, where `scores` are splice-site usage probabilities, or the probability
            change (`delta`) when an alternative sequence is scored.
        """
        if isinstance(outputs, Tensor):
            return outputs, self.output_channels
        scores = outputs["delta"] if outputs.get("delta") is not None else outputs["probabilities"]
        return scores, self.output_channels

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        alternative_input_ids: Tensor | NestedTensor | None = None,
        alternative_attention_mask: Tensor | None = None,
        alternative_inputs_embeds: Tensor | NestedTensor | None = None,
        reference_input_ids: Tensor | NestedTensor | None = None,
        reference_attention_mask: Tensor | None = None,
        reference_inputs_embeds: Tensor | NestedTensor | None = None,
        reference_usage: Tensor | None = None,
        use_reference: bool | None = None,
        output_contexts: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> DeltaSpliceModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if alternative_input_ids is not None and alternative_inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both alternative_input_ids and alternative_inputs_embeds at the same time"
            )
        if reference_input_ids is not None and reference_inputs_embeds is not None:
            raise ValueError("You cannot specify both reference_input_ids and reference_inputs_embeds at the same time")

        output_contexts = output_contexts if output_contexts is not None else self.config.output_contexts
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        record_contexts = bool(output_contexts) or bool(output_hidden_states)

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor
        if isinstance(alternative_input_ids, NestedTensor):
            if alternative_attention_mask is None:
                alternative_attention_mask = alternative_input_ids.mask
            alternative_input_ids = alternative_input_ids.tensor
        if isinstance(alternative_inputs_embeds, NestedTensor):
            if alternative_attention_mask is None:
                alternative_attention_mask = alternative_inputs_embeds.mask
            alternative_inputs_embeds = alternative_inputs_embeds.tensor
        if isinstance(reference_input_ids, NestedTensor):
            if reference_attention_mask is None:
                reference_attention_mask = reference_input_ids.mask
            reference_input_ids = reference_input_ids.tensor
        if isinstance(reference_inputs_embeds, NestedTensor):
            if reference_attention_mask is None:
                reference_attention_mask = reference_inputs_embeds.mask
            reference_inputs_embeds = reference_inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        alternative_embedding_output = None
        if alternative_input_ids is not None or alternative_inputs_embeds is not None:
            alternative_embedding_output = self.embeddings(
                input_ids=alternative_input_ids,
                attention_mask=alternative_attention_mask,
                inputs_embeds=alternative_inputs_embeds,
            )
        reference_embedding_output = None
        if reference_input_ids is not None or reference_inputs_embeds is not None:
            reference_embedding_output = self.embeddings(
                input_ids=reference_input_ids,
                attention_mask=reference_attention_mask,
                inputs_embeds=reference_inputs_embeds,
            )

        if alternative_embedding_output is not None and reference_embedding_output is not None:
            raise ValueError("Use either alternative_* inputs or reference_* inputs in one DeltaSplice call, not both.")

        if alternative_embedding_output is None:
            member_outputs = [
                member(
                    embedding_output,
                    reference_embeds=reference_embedding_output,
                    reference_usage=reference_usage,
                    output_contexts=record_contexts,
                    output_hidden_states=record_contexts,
                )
                for member in self.members
            ]
        else:
            member_outputs = [
                _member_variant_output(
                    member,
                    embedding_output,
                    alternative_embedding_output,
                    reference_usage=reference_usage,
                    use_reference=bool(use_reference) or reference_usage is not None,
                    output_contexts=record_contexts,
                    output_hidden_states=record_contexts,
                )
                for member in self.members
            ]

        last_hidden_state = _average_tensors([out.last_hidden_state for out in member_outputs])
        probabilities = _average_tensors([out.probabilities for out in member_outputs])
        site_probabilities = _average_tensors([out.site_probabilities for out in member_outputs])
        alternative_probabilities = _average_tensors([out.alternative_probabilities for out in member_outputs])
        delta = _average_tensors([out.delta for out in member_outputs])

        contexts: tuple[Tensor, ...] | None = None
        if record_contexts:
            per_member_contexts = [out.contexts for out in member_outputs if out.contexts is not None]
            if per_member_contexts:
                num_contexts = len(per_member_contexts[0])
                contexts = tuple(
                    _average_tensors([member_contexts[index] for member_contexts in per_member_contexts])
                    for index in range(num_contexts)
                )

        return DeltaSpliceModelOutput(
            last_hidden_state=last_hidden_state,
            probabilities=probabilities,
            site_probabilities=site_probabilities,
            alternative_probabilities=alternative_probabilities,
            delta=delta,
            contexts=contexts if output_contexts else None,
            hidden_states=contexts if output_hidden_states else None,
        )
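The input checks in forward can be summarized as a small decision function. The sketch below is a hypothetical restatement and is not part of the library. It rejects calls that supply no input, rejects mixing alternative_* and reference_* inputs, and reports which branch runs. It also shows that on the variant branch, passing reference_usage turns on the reference-informed path even when use_reference is not set.

```python
def resolve_branch(has_input, has_alternative=False, has_reference=False,
                   use_reference=None, has_reference_usage=False):
    """Return (branch, effective_use_reference), following DeltaSpliceModel.forward (hypothetical helper)."""
    if not has_input:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    if has_alternative and has_reference:
        raise ValueError("Use either alternative_* inputs or reference_* inputs in one DeltaSplice call, not both.")
    if not has_alternative:
        # Single-sequence branch: reference_* inputs and reference_usage are passed straight to each member.
        return "single", None
    # Variant branch: use_reference=bool(use_reference) or reference_usage is not None.
    return "variant", bool(use_reference) or has_reference_usage


branch = resolve_branch(True, has_alternative=True, has_reference_usage=True)  # ("variant", True)
```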

postprocess

Python
postprocess(
    outputs: DeltaSpliceModelOutput | ModelOutput | Tensor,
) -> tuple[Tensor, list[str]]

Return DeltaSplice splice-site usage probabilities with semantic channel names.

Parameters:

  • outputs (DeltaSpliceModelOutput | ModelOutput | Tensor, required): The output of DeltaSpliceModel, or its probabilities tensor.

Returns:

  • tuple[Tensor, list[str]]: A tuple of (scores, channels). scores holds the splice-site usage probabilities, or the probability change (delta) when an alternative sequence is scored. channels holds the channel names (no_splice, acceptor, donor).

Source code in multimolecule/models/deltasplice/modeling_deltasplice.py
Python
def postprocess(self, outputs: DeltaSpliceModelOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
    r"""
    Return DeltaSplice splice-site usage probabilities with semantic channel names.

    Args:
        outputs: The output of [`DeltaSpliceModel`][multimolecule.models.DeltaSpliceModel], or its
            `probabilities` tensor.

    Returns:
        A tuple of `(scores, channels)`, where `scores` are splice-site usage probabilities, or the probability
        change (`delta`) when an alternative sequence is scored.
    """
    if isinstance(outputs, Tensor):
        return outputs, self.output_channels
    scores = outputs["delta"] if outputs.get("delta") is not None else outputs["probabilities"]
    return scores, self.output_channels
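`postprocess` returns `delta` whenever it is present and falls back to `probabilities` otherwise; a bare tensor is passed through unchanged. The sketch below reproduces that selection rule with plain lists and dictionaries. `select_scores` is a hypothetical name, and the channel list `["no_splice", "acceptor", "donor"]` is an assumption based on the `probabilities` description rather than a read of `self.output_channels`.

```python
from typing import Dict, List, Optional, Tuple, Union

# Assumed channel order, taken from the `probabilities` description; the real
# list lives on the model as `self.output_channels`.
OUTPUT_CHANNELS = ["no_splice", "acceptor", "donor"]

Scores = List[List[float]]  # (sequence_length, num_labels); batch dimension omitted


def select_scores(
    outputs: Union[Scores, Dict[str, Optional[Scores]]],
) -> Tuple[Scores, List[str]]:
    """Return `delta` if present, else `probabilities`, with the channel names."""
    if isinstance(outputs, list):  # already a bare scores "tensor"
        return outputs, OUTPUT_CHANNELS
    delta = outputs.get("delta")
    scores = delta if delta is not None else outputs["probabilities"]
    return scores, OUTPUT_CHANNELS


reference_only = {"probabilities": [[0.9, 0.05, 0.05]], "delta": None}
with_alternative = {"probabilities": [[0.9, 0.05, 0.05]], "delta": [[-0.2, 0.2, 0.0]]}

print(select_scores(reference_only)[0])    # [[0.9, 0.05, 0.05]]
print(select_scores(with_alternative)[0])  # [[-0.2, 0.2, 0.0]]
```

Testing `is not None` rather than truthiness matters with real tensors, whose truth value is ambiguous when they hold more than one element.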

DeltaSpliceModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the DeltaSplice model.

Parameters:

Name Type Description Default

last_hidden_state

`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`

Per-position encoder representation averaged across ensemble members.

None

probabilities

`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`

DeltaSplice splice-site usage probabilities (no_splice, acceptor, donor) averaged across ensemble members. These are softmax-normalised; DeltaSplice does not expose pre-softmax logits.

None

site_probabilities

`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`

DeltaSplice splice-site probability module outputs averaged across ensemble members.

None

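alternative_probabilities

`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`, *optional*

Alternative-sequence splice-site usage probabilities when an alternative sequence is supplied.

None

delta

`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`, *optional*

Probability change, alternative_probabilities - probabilities, when an alternative sequence is supplied.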

None

contexts

`tuple(torch.FloatTensor)`, *optional*

Per-layer valid-convolution representations cropped to the input sequence length.

None

hidden_states

`tuple(torch.FloatTensor)`, *optional*

Same content as contexts; provided for Transformers hidden-state compatibility.

None
Source code in multimolecule/models/deltasplice/modeling_deltasplice.py
Python
@dataclass
class DeltaSpliceModelOutput(ModelOutput):
    """
    Base class for outputs of the DeltaSplice model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Per-position encoder representation averaged across ensemble members.
        probabilities (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            DeltaSplice splice-site usage probabilities (`no_splice`, `acceptor`, `donor`) averaged across ensemble
            members. These are softmax-normalised; DeltaSplice has no pre-softmax logit surface.
        site_probabilities (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            DeltaSplice splice-site probability module outputs averaged across ensemble members.
        alternative_probabilities (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`, *optional*):
            Alternative-sequence splice-site usage probabilities when an alternative sequence is supplied.
        delta (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`, *optional*):
            `alternative_probabilities - probabilities` when an alternative sequence is supplied.
        contexts (`tuple(torch.FloatTensor)`, *optional*):
            Per-layer valid-convolution representations cropped to the input sequence length.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Same content as `contexts`; provided for Transformers hidden-state compatibility.
    """

    last_hidden_state: torch.FloatTensor | None = None
    probabilities: torch.FloatTensor | None = None
    site_probabilities: torch.FloatTensor | None = None
    alternative_probabilities: torch.FloatTensor | None = None
    delta: torch.FloatTensor | None = None
    contexts: tuple[torch.FloatTensor, ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
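
Because both `probabilities` and `alternative_probabilities` are softmax-normalised over the three channels, each position of `delta` sums to zero. The sketch below checks that property on hand-written example rows (not real model output). `compute_delta` and the "largest acceptor/donor change" summary in `site_change` are hypothetical illustrations, not part of the DeltaSplice API.

```python
import math
from typing import List

Row = List[float]  # one position: (no_splice, acceptor, donor)


def compute_delta(reference: List[Row], alternative: List[Row]) -> List[Row]:
    """Position-wise alternative - reference, mirroring `delta` in the output."""
    if len(reference) != len(alternative):
        raise ValueError("reference and alternative must cover the same positions")
    return [
        [alt - ref for ref, alt in zip(ref_row, alt_row)]
        for ref_row, alt_row in zip(reference, alternative)
    ]


def site_change(row: Row) -> float:
    """Hypothetical summary: the larger absolute acceptor or donor change."""
    return max(abs(row[1]), abs(row[2]))


# Hand-written, softmax-style rows (each sums to 1); not real model output.
reference = [[0.90, 0.05, 0.05], [0.20, 0.70, 0.10]]
alternative = [[0.60, 0.35, 0.05], [0.20, 0.70, 0.10]]

delta = compute_delta(reference, alternative)

# Both inputs sum to 1 per position, so every delta row sums to (about) zero.
print(all(math.isclose(sum(row), 0.0, abs_tol=1e-9) for row in delta))  # True

most_changed = max(range(len(delta)), key=lambda i: site_change(delta[i]))
print(most_changed)  # 0
```

A gain in acceptor or donor usage at one position is therefore always matched by an equal loss in the other channels at that same position.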

DeltaSplicePreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/deltasplice/modeling_deltasplice.py
Python
class DeltaSplicePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DeltaSpliceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeltaSpliceLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.BatchNorm1d) and module.affine:
            init.ones_(module.weight)
            init.zeros_(module.bias)
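
The initialisation above follows PyTorch's fan-based rules. The sketch below recomputes the resulting scales in pure Python for an assumed convolution weight of shape `(64, 64, 11)` (hidden size 64 from the specification table; the kernel width 11 is a hypothetical choice) and an assumed `Linear(64, 3)` head. It also shows that `kaiming_uniform_` with `a=math.sqrt(5)` yields the same bound, `1 / sqrt(fan_in)`, as the bias initialisation. The helper names are illustrative, not library functions.

```python
import math
from typing import Sequence, Tuple


def fan_in_and_fan_out(weight_shape: Sequence[int]) -> Tuple[int, int]:
    """Fan computation for a weight of shape (out_features, in_features, *kernel)."""
    if len(weight_shape) < 2:
        raise ValueError("fan in/out needs a weight with at least 2 dimensions")
    receptive_field = math.prod(weight_shape[2:])  # empty product is 1 for Linear
    return weight_shape[1] * receptive_field, weight_shape[0] * receptive_field


def kaiming_normal_std(fan: int, negative_slope: float = 0.0) -> float:
    """std = gain / sqrt(fan); gain = sqrt(2 / (1 + slope^2)) is sqrt(2) for ReLU."""
    gain = math.sqrt(2.0 / (1.0 + negative_slope**2))
    return gain / math.sqrt(fan)


def kaiming_uniform_bound(fan: int, a: float) -> float:
    """Half-width b of U(-b, b) for kaiming_uniform_ with leaky-ReLU slope `a`."""
    gain = math.sqrt(2.0 / (1.0 + a**2))
    return gain * math.sqrt(3.0 / fan)


def bias_bound(fan_in: int) -> float:
    """Bias range used in `_init_weights`: U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
    return 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0


# Assumed Conv1d weight: 64 output channels, 64 input channels, kernel width 11.
conv_fan_in, conv_fan_out = fan_in_and_fan_out((64, 64, 11))
print(conv_fan_in, conv_fan_out)                   # 704 704
print(round(kaiming_normal_std(conv_fan_out), 4))  # 0.0533
print(round(bias_bound(conv_fan_in), 4))           # 0.0377

# Assumed Linear(64, 3) weight: shape (3, 64).
linear_fan_in, _ = fan_in_and_fan_out((3, 64))
print(round(kaiming_uniform_bound(linear_fan_in, a=math.sqrt(5)), 4))  # 0.125
print(round(bias_bound(linear_fan_in), 4))                             # 0.125
```

The `nn.Linear` branch matches PyTorch's default `nn.Linear` initialisation, whereas the `nn.Conv1d` branch departs from the framework default by using a ReLU-gain normal distribution scaled by `fan_out`.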

DeltaSpliceTokenPredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of DeltaSplice token prediction models.

Parameters:

Name Type Description Default

loss

`torch.FloatTensor`, *optional*, returned when `labels` is provided

Token prediction loss.

None

logits

`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`

Per-nucleotide token prediction outputs.

None

contexts

`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`

Per-layer context representations.

None

hidden_states

`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`

Per-layer context representations.

None
Source code in multimolecule/models/deltasplice/modeling_deltasplice.py
Python
@dataclass
class DeltaSpliceTokenPredictorOutput(ModelOutput):
    """
    Base class for outputs of DeltaSplice token prediction models.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Token prediction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-nucleotide token prediction outputs.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`):
            Per-layer context representations.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Per-layer context representations.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contexts: tuple[torch.FloatTensor, ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None