
OpenSpliceAI

Modular native-PyTorch reimplementation of SpliceAI for predicting pre-mRNA splice sites from primary DNA sequence.

Disclaimer

This is an UNOFFICIAL implementation of OpenSpliceAI: An efficient, modular implementation of SpliceAI enabling easy retraining on non-human species by Kuan-Hao Chao, Alan Mao et al.

The OFFICIAL repository of OpenSpliceAI is at Kuanhao-Chao/OpenSpliceAI.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing OpenSpliceAI did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

OpenSpliceAI is a deep dilated residual convolutional neural network that reimplements the SpliceAI architecture in native PyTorch. It predicts, for each nucleotide of a pre-mRNA transcript, whether the position is a splice acceptor, a splice donor, or neither. The model stacks dilated residual units with increasing kernel size and atrous rate so that a wide genomic context window contributes to each per-nucleotide prediction, while skip connections aggregate multi-scale features. OpenSpliceAI reproduces the predictive behavior of SpliceAI while providing an efficient, modular training pipeline that can be retrained on non-human species.
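The sketch below makes dilated ("atrous") convolution concrete. It is a single-channel, unpadded 1D convolution in pure Python. It only illustrates the idea and is not the OpenSpliceAI layer; the function name `dilated_conv1d` and the toy weights are assumptions. With kernel size k and dilation d, each output reads inputs spread over d * (k - 1) + 1 positions. The output is therefore d * (k - 1) positions shorter than the input. Stacking such layers with growing k and d is how the network covers a wide flanking context.

```python
def dilated_conv1d(signal: list[float], weights: list[float], dilation: int) -> list[float]:
    """Valid (unpadded) 1D cross-correlation whose kernel taps are `dilation` positions apart."""
    span = dilation * (len(weights) - 1)  # distance between the first and the last tap
    out = []
    for start in range(len(signal) - span):
        # Taps sit at start, start + dilation, start + 2 * dilation, ...
        out.append(sum(w * signal[start + j * dilation] for j, w in enumerate(weights)))
    return out


# Three unit weights with dilation 2: each output sums inputs two positions apart.
print(dilated_conv1d([float(x) for x in range(10)], [1.0, 1.0, 1.0], dilation=2))
```

Take a length-10 input with dilation 2. It yields six outputs, because the kernel spans 5 positions. With dilation 1 the same kernel spans only 3 positions and yields eight outputs.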

Variants

OpenSpliceAI ships trained model families for human MANE and four non-human species. Each family provides four flanking-context sizes. The listed Hub repositories use one deterministic seed (rs10) for each family/context pair; the other seeds are training replicates and are not exposed as separate model variants.

| Family | 80 nt | 400 nt | 2,000 nt | 10,000 nt |
| --- | --- | --- | --- | --- |
| MANE / human | openspliceai-mane-80nt | openspliceai-mane-400nt | openspliceai-mane-2000nt | openspliceai-mane-10000nt |
| Mouse | openspliceai-mouse-80nt | openspliceai-mouse-400nt | openspliceai-mouse-2000nt | openspliceai-mouse-10000nt |
| Zebrafish | openspliceai-zebrafish-80nt | openspliceai-zebrafish-400nt | openspliceai-zebrafish-2000nt | openspliceai-zebrafish-10000nt |
| Honeybee | openspliceai-honeybee-80nt | openspliceai-honeybee-400nt | openspliceai-honeybee-2000nt | openspliceai-honeybee-10000nt |
| Arabidopsis | openspliceai-arabidopsis-80nt | openspliceai-arabidopsis-400nt | openspliceai-arabidopsis-2000nt | openspliceai-arabidopsis-10000nt |

Model Specification

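| Flanking Context | Residual Blocks | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) |
| --- | --- | --- | --- | --- | --- |
| 80 nt | 4 | 32 | 0.09 | 0.95 | 0.47 |
| 400 nt | 8 | 32 | 0.19 | 2.00 | 0.99 |
| 2,000 nt | 12 | 32 | 0.36 | 5.03 | 2.50 |
| 10,000 nt | 16 | 32 | 0.70 | 20.90 | 10.40 |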

Model size is determined by flanking context and is shared across species for the same context. FLOPs and MACs are reported for a single 5,000-nucleotide output sequence.
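The Residual Blocks and Flanking Context columns follow from the stage layout. This sketch rests on an assumption inferred from the default `OpenSpliceAiConfig` stages: each smaller variant keeps a prefix of the four default stages, each with 4 blocks. These per-variant stage lists were not read from the released checkpoints. Under that assumption, the formula behind the config's `cropping` property, `2 * sum(dilation * (kernel_size - 1) * num_blocks)`, reproduces the four context sizes. The helper names are hypothetical.

```python
# Default stages from OpenSpliceAiConfig as (num_blocks, kernel_size, dilation).
DEFAULT_STAGES = [(4, 11, 1), (4, 11, 4), (4, 21, 10), (4, 41, 25)]


def context_for(stages):
    """Context consumed by a stage list, mirroring the formula of OpenSpliceAiConfig.cropping."""
    return 2 * sum(d * (k - 1) * n for n, k, d in stages)


def variant_summary():
    """(residual blocks, flanking context) for each assumed prefix of the default stages."""
    rows = []
    for num_stages in range(1, len(DEFAULT_STAGES) + 1):
        stages = DEFAULT_STAGES[:num_stages]
        rows.append((sum(n for n, _, _ in stages), context_for(stages)))
    return rows


for blocks, context in variant_summary():
    print(f"{blocks:>2} blocks -> {context:,} nt context")
```

Under this assumption the four prefixes give 4, 8, 12, and 16 blocks. These match the table rows. They also explain why hidden size stays at 32 while parameters and FLOPs grow with context.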

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

RNA Splicing Site Prediction

You can use this model directly to predict the splice sites of a pre-mRNA sequence:

Python
>>> from multimolecule import DnaTokenizer, OpenSpliceAiForTokenPrediction

>>> model_id = "multimolecule/openspliceai-mane-10000nt"
>>> tokenizer = DnaTokenizer.from_pretrained(model_id)
>>> model = OpenSpliceAiForTokenPrediction.from_pretrained(model_id)
>>> output = model(tokenizer("AGCAGTCATTATGGCGAA", return_tensors="pt")["input_ids"])

>>> output.keys()
odict_keys(['logits'])

Each output position carries three logits corresponding to neither, acceptor, and donor.
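A softmax over the last dimension turns the logits into per-position probabilities. `OpenSpliceAiForTokenPrediction.postprocess` does this and pairs the scores with the channel names `no_splice`, `acceptor`, and `donor`. The pure-Python sketch below mirrors that computation on made-up logits, so it runs without downloading a checkpoint. The values are illustrative, not model output, and the helper names are hypothetical.

```python
import math

# Channel order used by output_channels when num_labels == 3.
CHANNELS = ["no_splice", "acceptor", "donor"]


def softmax(logits):
    """Numerically stable softmax over one position's logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def most_likely(per_position_logits):
    """Return the most probable channel name for each position."""
    calls = []
    for logits in per_position_logits:
        probs = softmax(logits)
        calls.append(CHANNELS[probs.index(max(probs))])
    return calls


# Illustrative logits for three positions (not real model output).
toy_logits = [[4.0, 0.5, 0.1], [0.2, 3.5, 0.3], [0.1, 0.4, 2.9]]
print(most_likely(toy_logits))  # ['no_splice', 'acceptor', 'donor']
```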

Interface

  • Input length: variable pre-mRNA sequence
  • Flanking context: 80, 400, 2,000, or 10,000 nt depending on the variant, split evenly on both sides of every predicted position
  • Padding: sequence ends padded with N
  • Output: per-position 3-class logits (neither, acceptor, donor)
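The padding convention can be sketched in pure Python. SpliceAI-style one-hot encoding commonly represents `N` as an all-zero vector. Assuming that convention, padding with `N` produces the same input as padding with zeros, which is how the `context` parameter of `OpenSpliceAiConfig` describes it. The helper names and the A/C/G/T channel order are assumptions for illustration. In practice, the tokenizer shipped with each checkpoint prepares the inputs.

```python
# Assumed channel order for the four one-hot nucleotide channels (vocab_size=4).
CHANNEL_ORDER = "ACGT"


def one_hot(sequence):
    """Encode a DNA string as rows of 4 channels; N (or any other symbol) becomes all zeros."""
    return [[1 if base == channel else 0 for channel in CHANNEL_ORDER] for base in sequence.upper()]


def pad_with_context(sequence, context):
    """Add context // 2 N's on each side so every input position sees its full flank."""
    flank = "N" * (context // 2)
    return flank + sequence + flank


padded = pad_with_context("AGCAGT", context=80)
encoded = one_hot(padded)
print(len(padded), encoded[0], encoded[40])  # 86 [0, 0, 0, 0] [1, 0, 0, 0]
```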

Training Details

OpenSpliceAI was trained to predict the location of splice donor and acceptor sites from primary DNA sequence, following the SpliceAI training methodology.

Training Data

The MANE variants were trained on transcripts from the GENCODE/MANE human reference annotation. The non-human variants use the species annotations released by OpenSpliceAI for mouse, zebrafish, honeybee, and Arabidopsis. For each predicted nucleotide, the model receives a flanking context of 80, 400, 2,000, or 10,000 nucleotides, split evenly across the two sides of the output sequence, with sequence ends padded with N. Annotated splice donor and acceptor sites serve as positive labels; all other positions are negative.
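The sketch below turns annotated sites into per-position targets. It uses the class order of the model's three output logits: 0 = neither, 1 = acceptor, 2 = donor. The function `make_labels` and the 0-based site coordinates are hypothetical. The OpenSpliceAI pipeline derives its labels from its own annotation processing.

```python
NEITHER, ACCEPTOR, DONOR = 0, 1, 2


def make_labels(length, acceptors, donors):
    """Per-position class indices: annotated sites are positive, every other position is negative."""
    labels = [NEITHER] * length
    for pos in acceptors:
        labels[pos] = ACCEPTOR
    for pos in donors:
        labels[pos] = DONOR
    return labels


print(make_labels(10, acceptors=[2], donors=[7]))  # [0, 0, 1, 0, 0, 0, 0, 2, 0, 0]
```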

Training Procedure

Pre-training

The model was trained to minimize a cross-entropy loss between predicted splice-site probabilities and the reference annotation.

  • Optimizer: Adam
  • Loss: cross-entropy
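With per-position three-class targets, the cross-entropy objective averages -log p(true class) over positions. The minimal pure-Python version below uses hypothetical helper names. In practice PyTorch's `torch.nn.functional.cross_entropy` would be used. The OpenSpliceAI paper may weight or mask positions in ways this sketch does not.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_total for x in logits]


def cross_entropy(per_position_logits, labels):
    """Mean negative log-likelihood of the true class over all positions."""
    losses = [-log_softmax(logits)[label] for logits, label in zip(per_position_logits, labels)]
    return sum(losses) / len(losses)


# Uniform logits give log(3) regardless of the labels.
print(round(cross_entropy([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [0, 2]), 4))  # 1.0986
```

Because almost every position is "neither", a model that always predicts that class already gets a low average loss. This is one reason the rare acceptor and donor positions matter when judging training.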

Please refer to the OpenSpliceAI paper for the full training protocol and hardware details.

Citation

BibTeX
@article{chao2025openspliceai,
  author    = {Chao, Kuan-Hao and Mao, Alan and Liu, Anqi and Salzberg, Steven L and Pertea, Mihaela},
  title     = {OpenSpliceAI: An efficient, modular implementation of SpliceAI enabling easy retraining on non-human species},
  journal   = {eLife},
  volume    = 14,
  pages     = {RP107454},
  year      = 2025,
  doi       = {10.7554/eLife.107454.3},
  publisher = {eLife Sciences Publications, Ltd}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the OpenSpliceAI paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.openspliceai

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

Name Type Description Default

alphabet

Alphabet | str | List[str] | None

alphabet to use for tokenization.

  • If None, the standard DNA alphabet will be used.
  • If it is a string, it should correspond to the name of a predefined alphabet. The options include
    • standard
    • iupac
    • streamline
    • nucleobase
  • If it is an alphabet or a list of characters, that specific alphabet will be used.
None

nmers

int

Size of kmer to tokenize.

1

codon

bool

Whether to tokenize into codons.

False

replace_U_with_T

bool

Whether to replace U with T.

True

do_upper_case

bool

Whether to convert input to uppercase.

True

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If it is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If it is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
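The splitting rules in `_tokenize` can be reproduced in isolation. This shows which token strings the `nmers=3` and `codon=True` examples above are built from. Overlapping 3-mers use a stride of 1, while codons use a stride of 3. This standalone sketch has a hypothetical function name and is not part of the library. It stops before the vocabulary lookup, so it yields strings rather than ids.

```python
def split_tokens(text, nmers=1, codon=False, replace_u_with_t=True):
    """Standalone copy of the splitting rules in DnaTokenizer._tokenize (sketch only)."""
    text = text.upper()
    if replace_u_with_t:
        text = text.replace("U", "T")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length must be a multiple of 3 for codons, got {len(text)}")
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # Overlapping windows: a sequence of length L yields L - nmers + 1 tokens.
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(split_tokens("tataaagta", nmers=3))  # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_tokens("tataaagta", codon=True))  # ['TAT', 'AAA', 'GTA']
```

The seven overlapping 3-mers and three codons line up with the seven and three non-special ids in the console examples above. The codon ids (84, 6, 71) reappear at the first, fourth, and seventh 3-mer positions.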

OpenSpliceAiConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of an OpenSpliceAiModel. It is used to instantiate an OpenSpliceAI model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the openspliceai-mane 10000nt architecture from Kuanhao-Chao/OpenSpliceAI.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

Name Type Description Default

vocab_size

int

Vocabulary size of the OpenSpliceAI model. Defines the number of different tokens that can be represented by the input_ids passed when calling OpenSpliceAiModel. Defaults to 4 (the one-hot nucleotide channels A, C, G, T).

4

context

int

The length of the context window. The input sequence is padded with context // 2 zeros on each side so that the per-nucleotide output keeps the input resolution.

10000

hidden_size

int

Dimensionality of the encoder layers.

32

stages

list[OpenSpliceAiStageConfig] | None

Configuration for each stage in the OpenSpliceAI model. Each stage is an OpenSpliceAiStageConfig object.

None

hidden_act

str

The non-linear activation function (function or string) in the encoder. String values are resolved through transformers.activations.ACT2FN.

'leaky_relu'

hidden_act_kwargs

dict[str, object] | None

Keyword arguments used when instantiating string activations. Defaults to {"negative_slope": 0.1} for "leaky_relu" to match the original OpenSpliceAI checkpoints.

None

batch_norm_eps

float

The epsilon used by the batch normalization layers.

1e-05

batch_norm_momentum

float

The momentum used by the batch normalization layers.

0.1

num_labels

int

Number of output labels (neither / acceptor / donor).

3

head

HeadConfig | None

The configuration of the prediction head.

None

output_contexts

bool

Whether to output the context vectors for each stage.

False

Examples:

Python Console Session
>>> from multimolecule import OpenSpliceAiConfig, OpenSpliceAiModel
>>> # Initializing an OpenSpliceAI multimolecule/openspliceai style configuration
>>> configuration = OpenSpliceAiConfig()
>>> # Initializing a model (with random weights) from the multimolecule/openspliceai style configuration
>>> model = OpenSpliceAiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/openspliceai/configuration_openspliceai.py
Python
class OpenSpliceAiConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of an
    [`OpenSpliceAiModel`][multimolecule.models.OpenSpliceAiModel]. It is used to instantiate an OpenSpliceAI model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the OpenSpliceAI
    [Kuanhao-Chao/OpenSpliceAI](https://github.com/Kuanhao-Chao/OpenSpliceAI) `openspliceai-mane` 10000nt architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the OpenSpliceAI model. Defines the number of different tokens that can be represented
            by the `input_ids` passed when calling [`OpenSpliceAiModel`].
            Defaults to 4 (the one-hot nucleotide channels `A`, `C`, `G`, `T`).
        context:
            The length of the context window. The input sequence will be padded with zeros of length `context // 2` on
            each side so that the per-nucleotide output keeps the input resolution.
        hidden_size:
            Dimensionality of the encoder layers.
        stages:
            Configuration for each stage in the OpenSpliceAI model. Each stage is an [`OpenSpliceAiStageConfig`] object.
        hidden_act:
            The non-linear activation function (function or string) in the encoder. String values are resolved through
            `transformers.activations.ACT2FN`.
        hidden_act_kwargs:
            Keyword arguments used when instantiating string activations. Defaults to `{"negative_slope": 0.1}` for
            `"leaky_relu"` to match the original OpenSpliceAI checkpoints.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers.
        num_labels:
            Number of output labels (neither / acceptor / donor).
        head:
            The configuration of the prediction head.
        output_contexts:
            Whether to output the context vectors for each stage.

    Examples:
        >>> from multimolecule import OpenSpliceAiConfig, OpenSpliceAiModel
        >>> # Initializing an OpenSpliceAI multimolecule/openspliceai style configuration
        >>> configuration = OpenSpliceAiConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/openspliceai style configuration
        >>> model = OpenSpliceAiModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "openspliceai"

    def __init__(
        self,
        vocab_size: int = 4,
        context: int = 10000,
        hidden_size: int = 32,
        stages: list[OpenSpliceAiStageConfig] | None = None,
        hidden_act: str = "leaky_relu",
        hidden_act_kwargs: dict[str, object] | None = None,
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        num_labels: int = 3,
        head: HeadConfig | None = None,
        output_contexts: bool = False,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        # OpenSpliceAI consumes raw one-hot nucleotide channels and does not use BOS/EOS tokens;
        # `pad_token_id` points at the `N` token (the last entry of the streamline DNA alphabet).
        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            **kwargs,
        )
        if hidden_act_kwargs is None and hidden_act == "leaky_relu":
            hidden_act_kwargs = {"negative_slope": 0.1}
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.context = context
        if stages is None:
            stages = [
                OpenSpliceAiStageConfig(num_blocks=4, kernel_size=11, dilation=1),
                OpenSpliceAiStageConfig(num_blocks=4, kernel_size=11, dilation=4),
                OpenSpliceAiStageConfig(num_blocks=4, kernel_size=21, dilation=10),
                OpenSpliceAiStageConfig(num_blocks=4, kernel_size=41, dilation=25),
            ]
        self.stages = stages
        self.hidden_act = hidden_act
        self.hidden_act_kwargs = hidden_act_kwargs or {}
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        # OpenSpliceAI performs per-nucleotide multi-class classification (neither / acceptor / donor).
        self.head = HeadConfig(head) if head is not None else HeadConfig(problem_type="multiclass")
        self.output_contexts = output_contexts

    @property
    def cropping(self) -> int:
        r"""
        Total number of context nucleotides removed by the model.

        Computed as ``2 * sum(dilation * (kernel_size - 1) * num_blocks)`` over all stages.
        """
        return 2 * sum(s["dilation"] * (s["kernel_size"] - 1) * s["num_blocks"] for s in self.stages)

cropping property

Python
cropping: int

Total number of context nucleotides removed by the model, computed as 2 * sum(dilation * (kernel_size - 1) * num_blocks) over all stages.
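This arithmetic check shows why `cropping` equals `context` for the default configuration. It also shows why padding `context // 2` positions on each side preserves the input length. The sketch assumes, as in the original SpliceAI residual unit, that each block applies two dilated convolutions. It further assumes each convolution trims `dilation * (kernel_size - 1)` positions. The real implementation may instead use "same" padding and crop once at the end; the length arithmetic is equivalent. `output_length` and `CONVS_PER_BLOCK` are hypothetical names.

```python
# (num_blocks, kernel_size, dilation) for the default 10,000 nt configuration.
DEFAULT_STAGES = [(4, 11, 1), (4, 11, 4), (4, 21, 10), (4, 41, 25)]
CONVS_PER_BLOCK = 2  # assumption borrowed from the SpliceAI residual unit


def output_length(input_length, context, stages=DEFAULT_STAGES):
    """Track sequence length through the padded input and every unpadded dilated convolution."""
    length = input_length + 2 * (context // 2)
    for num_blocks, kernel_size, dilation in stages:
        for _ in range(num_blocks * CONVS_PER_BLOCK):
            length -= dilation * (kernel_size - 1)
    return length


print(output_length(5000, context=10000))  # 5000
```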

OpenSpliceAiStageConfig

Bases: FlatDict

Configuration for a single OpenSpliceAI stage.

Parameters:

Name Type Description Default

num_blocks

Number of residual convolutional blocks in the stage.

required

kernel_size

Convolution kernel size for the stage.

required

dilation

Dilation (atrous) factor for the stage.

required
Source code in multimolecule/models/openspliceai/configuration_openspliceai.py
Python
class OpenSpliceAiStageConfig(FlatDict):
    r"""
    Configuration for a single OpenSpliceAI stage.

    Args:
        num_blocks:
            Number of residual convolutional blocks in the stage.
        kernel_size:
            Convolution kernel size for the stage.
        dilation:
            Dilation (atrous) factor for the stage.
    """

    num_blocks: int = 4
    kernel_size: int = 11
    dilation: int = 1

OpenSpliceAiForTokenPrediction

Bases: OpenSpliceAiPreTrainedModel

OpenSpliceAI model for per-nucleotide splice-site classification (neither / acceptor / donor).

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import OpenSpliceAiConfig, OpenSpliceAiForTokenPrediction
>>> config = OpenSpliceAiConfig()
>>> model = OpenSpliceAiForTokenPrediction(config)
>>> input_ids = torch.tensor([[0, 1, 2, 3, 4]])
>>> output = model(input_ids, labels=torch.randint(3, (1, 5)))
>>> output["logits"].shape
torch.Size([1, 5, 3])
>>> output["loss"]
tensor(..., grad_fn=<NllLossBackward0>)
Source code in multimolecule/models/openspliceai/modeling_openspliceai.py
Python
class OpenSpliceAiForTokenPrediction(OpenSpliceAiPreTrainedModel):
    """
    OpenSpliceAI model for per-nucleotide splice-site classification (neither / acceptor / donor).

    Examples:
        >>> import torch
        >>> from multimolecule import OpenSpliceAiConfig, OpenSpliceAiForTokenPrediction
        >>> config = OpenSpliceAiConfig()
        >>> model = OpenSpliceAiForTokenPrediction(config)
        >>> input_ids = torch.tensor([[0, 1, 2, 3, 4]])
        >>> output = model(input_ids, labels=torch.randint(3, (1, 5)))
        >>> output["logits"].shape
        torch.Size([1, 5, 3])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<NllLossBackward0>)
    """

    def __init__(self, config: OpenSpliceAiConfig):
        super().__init__(config)
        self.model = OpenSpliceAiModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        if self.config.num_labels == 3:
            return ["no_splice", "acceptor", "donor"]
        return [f"label_{index}" for index in range(self.config.num_labels)]

    def postprocess(self, outputs: OpenSpliceAiTokenPredictorOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
        r"""
        Return OpenSpliceAI splice-site probabilities with semantic channel names.

        Args:
            outputs: The output of
                [`OpenSpliceAiForTokenPrediction`][multimolecule.models.OpenSpliceAiForTokenPrediction], or its
                `logits` tensor.

        Returns:
            A tuple of `(scores, channels)`, where `scores` is softmax-normalized over splice-site classes.
        """
        logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
        return logits.softmax(dim=-1), self.output_channels

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        output_contexts: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> OpenSpliceAiTokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_contexts=output_contexts,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )

        output = self.token_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return OpenSpliceAiTokenPredictorOutput(
            loss=loss,
            logits=logits,
            contexts=outputs.contexts,
            hidden_states=outputs.hidden_states,
        )

postprocess

Python
postprocess(
    outputs: (
        OpenSpliceAiTokenPredictorOutput
        | ModelOutput
        | Tensor
    ),
) -> tuple[Tensor, list[str]]

Return OpenSpliceAI splice-site probabilities with semantic channel names.

Parameters:

Name Type Description Default
outputs
OpenSpliceAiTokenPredictorOutput | ModelOutput | Tensor

The output of OpenSpliceAiForTokenPrediction, or its logits tensor.

required

Returns:

Type Description
tuple[Tensor, list[str]]

A tuple of (scores, channels), where scores is softmax-normalized over splice-site classes.

Source code in multimolecule/models/openspliceai/modeling_openspliceai.py
Python
def postprocess(self, outputs: OpenSpliceAiTokenPredictorOutput | ModelOutput | Tensor) -> tuple[Tensor, list[str]]:
    r"""
    Return OpenSpliceAI splice-site probabilities with semantic channel names.

    Args:
        outputs: The output of
            [`OpenSpliceAiForTokenPrediction`][multimolecule.models.OpenSpliceAiForTokenPrediction], or its
            `logits` tensor.

    Returns:
        A tuple of `(scores, channels)`, where `scores` is softmax-normalized over splice-site classes.
    """
    logits = outputs if isinstance(outputs, Tensor) else outputs["logits"]
    return logits.softmax(dim=-1), self.output_channels
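Downstream, the `(scores, channels)` pair returned by `postprocess` is often reduced to called sites. The sketch below takes one sequence's scores as plain nested lists of shape `(sequence_length, 3)`. With a real output, these could come from something like `scores[0].tolist()`. The helper name `call_sites` and the 0.5 threshold are hypothetical choices for illustration, not values prescribed by OpenSpliceAI.

```python
def call_sites(scores, channels, threshold=0.5):
    """Return {channel: [positions]} where acceptor/donor probability is at or above threshold."""
    calls = {name: [] for name in channels if name != "no_splice"}
    for position, probs in enumerate(scores):
        for name, prob in zip(channels, probs):
            if name in calls and prob >= threshold:
                calls[name].append(position)
    return calls


channels = ["no_splice", "acceptor", "donor"]
scores = [[0.98, 0.01, 0.01], [0.10, 0.85, 0.05], [0.40, 0.30, 0.30], [0.05, 0.05, 0.90]]
print(call_sites(scores, channels))  # {'acceptor': [1], 'donor': [3]}
```

A lower threshold trades precision for recall. The right value depends on the downstream use.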

OpenSpliceAiModel

Bases: OpenSpliceAiPreTrainedModel

The bare OpenSpliceAI backbone producing per-nucleotide context representations.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import OpenSpliceAiConfig, OpenSpliceAiModel
>>> config = OpenSpliceAiConfig()
>>> model = OpenSpliceAiModel(config)
>>> input_ids = torch.tensor([[0, 1, 2, 3, 4]])
>>> output = model(input_ids)
>>> output["last_hidden_state"].shape
torch.Size([1, 5, 32])
Source code in multimolecule/models/openspliceai/modeling_openspliceai.py
Python
class OpenSpliceAiModel(OpenSpliceAiPreTrainedModel):
    """
    The bare OpenSpliceAI backbone producing per-nucleotide context representations.

    Examples:
        >>> import torch
        >>> from multimolecule import OpenSpliceAiConfig, OpenSpliceAiModel
        >>> config = OpenSpliceAiConfig()
        >>> model = OpenSpliceAiModel(config)
        >>> input_ids = torch.tensor([[0, 1, 2, 3, 4]])
        >>> output = model(input_ids)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 5, 32])
    """

    def __init__(self, config: OpenSpliceAiConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = OpenSpliceAiEmbedding(config)
        self.encoder = OpenSpliceAiEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        output_contexts: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> OpenSpliceAiModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        output_contexts = output_contexts if output_contexts is not None else self.config.output_contexts
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        record_contexts = bool(output_contexts) or bool(output_hidden_states)

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(embedding_output, output_hidden_states=record_contexts)
        last_hidden_state = encoder_outputs.last_hidden_state.transpose(1, 2)
        contexts = None
        if encoder_outputs.hidden_states is not None:
            contexts = tuple(hidden_state.transpose(1, 2) for hidden_state in encoder_outputs.hidden_states)

        return OpenSpliceAiModelOutput(
            last_hidden_state=last_hidden_state,
            contexts=contexts if output_contexts else None,
            hidden_states=contexts if output_hidden_states else None,
        )

OpenSpliceAiModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the OpenSpliceAI backbone.

Parameters:

Name Type Description Default

last_hidden_state

`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`

Per-nucleotide context representation after the dilated residual stack.

None

contexts

`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`

Per-stage context representations.

None

hidden_states

`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`

Per-stage context representations.

None
Source code in multimolecule/models/openspliceai/modeling_openspliceai.py
Python
@dataclass
class OpenSpliceAiModelOutput(ModelOutput):
    """
    Base class for outputs of the OpenSpliceAI backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Per-nucleotide context representation after the dilated residual stack.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`):
            Per-stage context representations.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Per-stage context representations.
    """

    last_hidden_state: torch.FloatTensor | None = None
    contexts: Tuple[torch.FloatTensor, ...] | None = None
    hidden_states: Tuple[torch.FloatTensor, ...] | None = None

OpenSpliceAiPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/openspliceai/modeling_openspliceai.py
Python
class OpenSpliceAiPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = OpenSpliceAiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OpenSpliceAiBlock"]

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="leaky_relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm1d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
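This calculation shows the scale `_init_weights` gives the convolution kernels. `nn.init.kaiming_normal_` samples from a normal distribution with standard deviation `gain / sqrt(fan)`. With `mode="fan_out"`, PyTorch takes the fan of a `Conv1d` weight of shape `(out_channels, in_channels, kernel_size)` as `out_channels * kernel_size`. Because no `a` is passed, the leaky-ReLU gain is `sqrt(2 / (1 + 0**2)) = sqrt(2)`. So the 0.1 negative slope used at run time does not enter the initializer. The helper name below is hypothetical.

```python
import math


def kaiming_fan_out_std(out_channels, kernel_size, negative_slope=0.0):
    """Std of kaiming_normal_(mode="fan_out", nonlinearity="leaky_relu", a=negative_slope) for a Conv1d weight."""
    gain = math.sqrt(2.0 / (1 + negative_slope**2))
    fan_out = out_channels * kernel_size
    return gain / math.sqrt(fan_out)


# hidden_size=32 output channels with the kernel sizes used by the default stages.
for kernel_size in (11, 21, 41):
    print(kernel_size, round(kaiming_fan_out_std(32, kernel_size), 4))
```

Wider kernels get proportionally smaller initial weights, which keeps the variance of each layer's outputs roughly constant.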

OpenSpliceAiTokenPredictorOutput dataclass

Bases: ModelOutput

Base class for outputs of OpenSpliceAI token prediction models.

Parameters:

Name Type Description Default

loss

`torch.FloatTensor`, *optional*, returned when `labels` is provided

Token prediction loss.

None

logits

`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`

Per-nucleotide splice-site classification scores.

None

contexts

`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`

Per-stage context representations.

None

hidden_states

`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`

Per-stage context representations.

None
Source code in multimolecule/models/openspliceai/modeling_openspliceai.py
Python
@dataclass
class OpenSpliceAiTokenPredictorOutput(ModelOutput):
    """
    Base class for outputs of OpenSpliceAI token prediction models.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Token prediction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-nucleotide splice-site classification scores.
        contexts (`tuple(torch.FloatTensor)`, *optional*, returned when `output_contexts=True`):
            Per-stage context representations.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Per-stage context representations.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contexts: Tuple[torch.FloatTensor, ...] | None = None
    hidden_states: Tuple[torch.FloatTensor, ...] | None = None