
MTSplice

Tissue-specific modeling of the effects of genetic variants on splicing.

Disclaimer

This is an UNOFFICIAL implementation of MTSplice, as described in "MTSplice predicts effects of genetic variants on tissue-specific splicing" by Jun Cheng et al.

The OFFICIAL repository of MTSplice is at gagneurlab/MMSplice_MTSplice.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing MTSplice did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

MTSplice is the tissue-specific second generation of MMSplice. It predicts the effect of genetic variants on cassette-exon splicing across 56 GTEx tissues. The cassette exon together with its flanking introns is fed into two parallel sequence towers whose outputs are combined into a per-tissue delta-logit-PSI splicing-effect vector. Please refer to the Training Details section for more information on the training process.

MTSplice is distributed as a four-member deep ensemble (mtsplice_deep0..3) and an earlier eight-member ensemble (mtsplice0..7). The default model here is a single deterministic model based on mtsplice_deep0, not an average over the deep ensemble.

Model Specification

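| Num Blocks | Hidden Size | Num Tissues | Num Parameters | FLOPs (M) | MACs (M) |
|------------|-------------|-------------|----------------|-----------|----------|
| 8          | 64          | 56          | 210,840        | 164.36    | 80.90    |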

(Num Blocks is per tower; FLOPs and MACs measured on an 800 bp cassette-exon-with-flanks input.)

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Tissue Scores

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, MtSpliceModel

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/mtsplice")
>>> model = MtSpliceModel.from_pretrained("multimolecule/mtsplice")
>>> reference = tokenizer("agcagtcattatggcgaatctggcaagta", return_tensors="pt")
>>> output = model(**reference)
>>> output["logits"].shape
torch.Size([1, 56])

Variant Effect

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, MtSpliceForSequencePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/mtsplice")
>>> model = MtSpliceForSequencePrediction.from_pretrained("multimolecule/mtsplice")
>>> reference = tokenizer("agcagtcattatggcgaatctggcaagta", return_tensors="pt")
>>> alternative = tokenizer("agcagtcattatggctaatctggcaagta", return_tensors="pt")
>>> output = model(
...     reference["input_ids"],
...     alternative_input_ids=alternative["input_ids"],
... )
>>> output["logits"].shape
torch.Size([1, 56])
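In the variant-effect example, the alternative sequence differs from the reference by a single G>T substitution at 0-based position 15. The sketch below shows one way to build such an alternative while checking the reference allele first; `apply_snv` is a hypothetical helper written for illustration and is not part of multimolecule.

```python
def apply_snv(reference: str, position: int, ref_base: str, alt_base: str) -> str:
    """Return `reference` with a single-nucleotide substitution at 0-based `position`.

    Hypothetical helper for illustration only; it is not part of multimolecule.
    """
    if not 0 <= position < len(reference):
        raise IndexError(f"position {position} is outside a sequence of length {len(reference)}")
    # Compare case-insensitively, since the tokenizer upper-cases its input anyway.
    if reference[position].upper() != ref_base.upper():
        raise ValueError(f"reference base at {position} is {reference[position]!r}, expected {ref_base!r}")
    return reference[:position] + alt_base + reference[position + 1 :]


reference = "agcagtcattatggcgaatctggcaagta"
alternative = apply_snv(reference, 15, "g", "t")
print(alternative)  # agcagtcattatggctaatctggcaagta
```

Checking the reference allele guards against off-by-one coordinate mistakes (for example, mixing 0-based and 1-based variant positions), which would otherwise silently score the wrong variant.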

Interface

  • Input length: cassette exon with flanking intronic context (typically ~800 bp, i.e. acceptor_length + donor_length = 400 + 400 by default)
  • Output (reference-only call, input_ids / inputs_embeds): per-tissue score vector logits of shape (batch_size, 56)

Variant Effect

  • Reference + alternative call (also pass alternative_input_ids / alternative_inputs_embeds): additionally returns alternative_logits and per-tissue delta_logits = alternative_logits - logits
  • MtSpliceForSequencePrediction: returns per-tissue deltas (or per-tissue scores when no alternative is supplied); applies standard regression loss when labels are provided
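The per-tissue deltas live on the logit scale. Assuming `delta_logits` is read as logit(PSI_alt) - logit(PSI_ref) for each tissue, a predicted delta can be turned into an alternative PSI whenever a reference PSI for that tissue is available (for example, measured from RNA-seq). The helpers and numbers below are hypothetical and only illustrate the arithmetic; this conversion is not part of the documented multimolecule API.

```python
import math


def logit(p: float) -> float:
    # Log-odds of a probability strictly between 0 and 1.
    return math.log(p / (1.0 - p))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def alternative_psi(reference_psi: float, delta_logit_psi: float) -> float:
    """Shift a reference PSI by a delta on the logit scale.

    Assumes delta_logit_psi = logit(PSI_alt) - logit(PSI_ref); reference_psi must lie in (0, 1).
    """
    return sigmoid(logit(reference_psi) + delta_logit_psi)


# Hypothetical values for three tissues: a measured reference PSI and a predicted delta each.
reference_psi = [0.5, 0.9, 0.2]
delta_logits = [math.log(3.0), 0.0, -1.0]
alt_psi = [alternative_psi(p, d) for p, d in zip(reference_psi, delta_logits)]
delta_psi = [a - p for a, p in zip(alt_psi, reference_psi)]
```

Because the logistic curve is steepest at PSI = 0.5, the same logit delta produces a larger PSI change for exons with intermediate inclusion than for exons that are almost always included or skipped.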

Training Details

MTSplice was trained to predict tissue-specific percent-spliced-in (PSI) of cassette exons across GTEx tissues, building on the MMSplice modular splicing model with an added tissue-specific neural module.

Training Data

MTSplice was trained on cassette-exon PSI quantifications across 56 GTEx tissues, together with the human reference splice-site and exon sequence context. The variant-effect predictions were validated against tissue-specific splicing quantitative trait loci (sQTL) and MPRA exon-skipping data.

Training Procedure

Pre-training

The two sequence towers consume one-hot encoded DNA. A dilated-convolution stack with positional B-spline re-weighting extracts splicing features, which a dense head maps to per-tissue delta-logit-PSI. The tissue-resolved predictions are formed from the reference/alternative score deltas.
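The positional B-spline re-weighting can be illustrated with a small pure-Python sketch. It assumes the layer multiplies the activation at position j by 1 + sum_k beta_k * B_k(j), where B_k are `spline_bases` B-spline bases of degree `spline_degree` on evenly spaced knots. Both the formula and the knot placement are assumptions, chosen because they make a zero-initialized `MtSpliceSplineWeight.kernel` (as set in `_init_weights`) an identity mapping; the real layer may also learn a separate curve per channel, whereas this sketch uses one.

```python
def bspline_basis(x: float, knots: list[float], i: int, degree: int) -> float:
    """Value of the i-th B-spline basis of the given degree at x (Cox-de Boor recursion)."""
    if degree == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    value = 0.0
    left_span = knots[i + degree] - knots[i]
    if left_span > 0:
        value += (x - knots[i]) / left_span * bspline_basis(x, knots, i, degree - 1)
    right_span = knots[i + degree + 1] - knots[i + 1]
    if right_span > 0:
        value += (knots[i + degree + 1] - x) / right_span * bspline_basis(x, knots, i + 1, degree - 1)
    return value


def uniform_knots(length: int, num_bases: int, degree: int) -> list[float]:
    # Evenly spaced knots, extended past both ends so the bases sum to 1 on [0, length).
    step = length / (num_bases - degree)
    return [(k - degree) * step for k in range(num_bases + degree + 1)]


def positional_weights(length: int, coefficients: list[float], degree: int = 3) -> list[float]:
    """Per-position multiplier 1 + sum_k coefficients[k] * B_k(position) (assumed form)."""
    knots = uniform_knots(length, len(coefficients), degree)
    return [
        1.0 + sum(c * bspline_basis(float(j), knots, k, degree) for k, c in enumerate(coefficients))
        for j in range(length)
    ]


# 10 cubic bases over a 400 bp window, matching the spline_bases, spline_degree and acceptor_length defaults.
length, num_bases, degree = 400, 10, 3
knots = uniform_knots(length, num_bases, degree)
totals = [sum(bspline_basis(float(j), knots, k, degree) for k in range(num_bases)) for j in range(length)]
identity = positional_weights(length, [0.0] * num_bases, degree)
```

Since the bases sum to 1 at every position, all-zero coefficients leave the activations untouched, and a handful of coefficients is enough to express a smooth up- or down-weighting across hundreds of positions.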

Citation

BibTeX
@article{cheng2021mtsplice,
  title     = {MTSplice predicts effects of genetic variants on tissue-specific splicing},
  author    = {Cheng, Jun and {\c{C}}elik, Muhammed Hasan and Kundaje, Anshul and Gagneur, Julien},
  journal   = {Genome Biology},
  volume    = 22,
  number    = 1,
  pages     = {94},
  year      = 2021,
  publisher = {Springer},
  doi       = {10.1186/s13059-021-02273-7}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the MTSplice paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.mtsplice

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

Name Type Description Default

alphabet

Alphabet | str | List[str] | None

alphabet to use for tokenization.

  • If None, the standard DNA alphabet will be used.
  • If a string, it should be the name of a predefined alphabet. The options are:
    • standard
    • iupac
    • streamline
    • nucleobase
  • If an alphabet or a list of characters, that specific alphabet will be used.
None

nmers

int

Size of kmer to tokenize.

1

codon

bool

Whether to tokenize into codons.

False

replace_U_with_T

bool

Whether to replace U with T.

True

do_upper_case

bool

Whether to convert input to uppercase.

True

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
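The 3-mer and codon IDs in these examples are consistent with enumerating every k-mer over ACGTN in lexicographic order, starting right after the six special tokens. The sketch below re-derives them under that assumption; it is an inference from the printed outputs, not the tokenizer's documented vocabulary construction (which goes through `get_alphabet`), and it omits the `<cls>`/`<eos>` wrappers and IUPAC symbols. `kmer_ids` is a hypothetical helper.

```python
from itertools import product

SPECIAL_TOKENS = 6  # <pad>, <cls>, <eos>, <unk>, <mask>, <null> occupy IDs 0-5
LETTERS = "ACGTN"   # assumed k-mer alphabet ordering, inferred from the examples


def kmer_ids(sequence: str, nmers: int, codon: bool = False) -> list[int]:
    """Hypothetical re-derivation of the k-mer IDs printed above (without <cls>/<eos>)."""
    vocab = {"".join(p): SPECIAL_TOKENS + i for i, p in enumerate(product(LETTERS, repeat=nmers))}
    text = sequence.upper().replace("U", "T")
    step = nmers if codon else 1  # codons do not overlap; plain k-mers slide by one base
    stop = len(text) - nmers + 1
    return [vocab[text[i : i + nmers]] for i in range(0, stop, step)]


print(kmer_ids("tataaagta", 3))              # [84, 21, 81, 6, 8, 19, 71]
print(kmer_ids("tataaagta", 3, codon=True))  # [84, 6, 71]
```

For example, TAT maps to 6 + 3*25 + 0*5 + 3 = 84, reading each letter as a base-5 digit (A=0, C=1, G=2, T=3, N=4).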
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should be the name of a predefined alphabet. The options are:
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)

MtSpliceConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a MtSpliceModel. It is used to instantiate an MTSplice model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the MTSplice gagneurlab/MMSplice_MTSplice architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

MTSplice (Cheng et al. 2021) is the tissue-specific second generation of MMSplice. It scores a cassette exon together with its flanking introns through two parallel dilated-convolution towers: an acceptor (3’ splice site) tower over the upstream region and a donor (5’ splice site) tower over the downstream region. The two towers are positionally re-weighted by B-spline transformations, pooled, and combined by a small dense head into a tissue-resolved delta-logit-PSI splicing-effect score across 56 GTEx tissues.

Parameters:

Name Type Description Default

vocab_size

int

Vocabulary size of the MTSplice model. Defines the number of feature channels derived from the one-hot encoded input_ids. Defaults to 4 (the ACGT nucleobase alphabet).

4

hidden_size

int

Number of convolution filters in the two sequence towers.

64

kernel_size

int

Kernel size of the first (stem) convolution in each tower.

11

num_blocks

int

Number of residual dilated-convolution blocks per tower.

8

block_kernel_size

int

Kernel size of the residual dilated-convolution blocks.

3

dilation_base

int

Base of the exponentially growing dilation rate; block i uses dilation dilation_base ** (i + 1).

2

acceptor_length

int

Length (in bp) of the acceptor (3’ splice site) input region, intron overhang plus exon flank.

400

donor_length

int

Length (in bp) of the donor (5’ splice site) input region, exon flank plus intron overhang.

400

spline_bases

int

Number of B-spline bases used by the positional re-weighting layers.

10

spline_degree

int

Polynomial degree of the B-spline bases.

3

mlp_size

int

Hidden size of the dense head that maps pooled features to tissue scores.

32

hidden_act

str

The non-linear activation function in the convolution towers and the dense head.

'relu'

batch_norm_eps

float

The epsilon used by the batch normalization layers. Defaults to 0.001 to match the upstream Keras BatchNormalization default.

0.001

hidden_dropout

float

The dropout probability applied before the tissue projection.

0.5

num_labels

int

Number of tissue outputs. MTSplice predicts delta-logit-PSI for the 56 GTEx tissues, so this defaults to 56.

56
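The documented dilation rule (block i uses dilation_base ** (i + 1)) determines how far each tower can look. The sketch below lists the dilations for the default configuration and estimates a tower's receptive field. `convs_per_block` is a hypothetical parameter: the number of dilated convolutions inside each residual block is not documented here, so the result is an estimate under the assumption of one convolution per block.

```python
def dilation_schedule(num_blocks: int = 8, dilation_base: int = 2) -> list[int]:
    # Block i (0-based) uses dilation dilation_base ** (i + 1), as documented for `dilation_base`.
    return [dilation_base ** (i + 1) for i in range(num_blocks)]


def receptive_field(
    kernel_size: int = 11,
    block_kernel_size: int = 3,
    num_blocks: int = 8,
    dilation_base: int = 2,
    convs_per_block: int = 1,  # assumption: not documented in this section
) -> int:
    """Estimated receptive field (in bp) of one tower."""
    field = kernel_size  # the stem convolution alone covers kernel_size positions
    for dilation in dilation_schedule(num_blocks, dilation_base):
        # Each dilated convolution widens the field by (k - 1) * dilation positions.
        field += convs_per_block * (block_kernel_size - 1) * dilation
    return field


print(dilation_schedule())  # [2, 4, 8, 16, 32, 64, 128, 256]
print(receptive_field())    # 1031
```

Under the one-convolution assumption, the estimated field (1031 bp) already exceeds the 400 bp window each tower sees, so with centered convolutions every output position can in principle draw on the entire window.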

Examples:

Python Console Session
>>> from multimolecule import MtSpliceConfig, MtSpliceModel
>>> # Initializing a MTSplice multimolecule/mtsplice style configuration
>>> configuration = MtSpliceConfig()
>>> # Initializing a model (with random weights) from the multimolecule/mtsplice style configuration
>>> model = MtSpliceModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/mtsplice/configuration_mtsplice.py
Python
class MtSpliceConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`MtSpliceModel`][multimolecule.models.MtSpliceModel]. It is used to instantiate an MTSplice model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the MTSplice
    [gagneurlab/MMSplice_MTSplice](https://github.com/gagneurlab/MMSplice_MTSplice) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    MTSplice (Cheng et al. 2021) is the tissue-specific second generation of MMSplice. It scores a cassette exon
    together with its flanking introns through two parallel dilated-convolution towers: an *acceptor* (3' splice
    site) tower over the upstream region and a *donor* (5' splice site) tower over the downstream region. The two
    towers are positionally re-weighted by B-spline transformations, pooled, and combined by a small dense head into
    a tissue-resolved delta-logit-PSI splicing-effect score across 56 GTEx tissues.

    Args:
        vocab_size:
            Vocabulary size of the MTSplice model. Defines the number of feature channels derived from the one-hot
            encoded `input_ids`. Defaults to 4 (the `ACGT` nucleobase alphabet).
        hidden_size:
            Number of convolution filters in the two sequence towers.
        kernel_size:
            Kernel size of the first (stem) convolution in each tower.
        num_blocks:
            Number of residual dilated-convolution blocks per tower.
        block_kernel_size:
            Kernel size of the residual dilated-convolution blocks.
        dilation_base:
            Base of the exponentially growing dilation rate; block `i` uses dilation `dilation_base ** (i + 1)`.
        acceptor_length:
            Length (in bp) of the acceptor (3' splice site) input region, intron overhang plus exon flank.
        donor_length:
            Length (in bp) of the donor (5' splice site) input region, exon flank plus intron overhang.
        spline_bases:
            Number of B-spline bases used by the positional re-weighting layers.
        spline_degree:
            Polynomial degree of the B-spline bases.
        mlp_size:
            Hidden size of the dense head that maps pooled features to tissue scores.
        hidden_act:
            The non-linear activation function in the convolution towers and the dense head.
        batch_norm_eps:
            The epsilon used by the batch normalization layers. Defaults to 0.001 to match the upstream
            Keras `BatchNormalization` default.
        hidden_dropout:
            The dropout probability applied before the tissue projection.
        num_labels:
            Number of tissue outputs. MTSplice predicts delta-logit-PSI for the 56 GTEx tissues, so this defaults
            to 56.

    Examples:
        >>> from multimolecule import MtSpliceConfig, MtSpliceModel
        >>> # Initializing a MTSplice multimolecule/mtsplice style configuration
        >>> configuration = MtSpliceConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/mtsplice style configuration
        >>> model = MtSpliceModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "mtsplice"

    def __init__(
        self,
        vocab_size: int = 4,
        hidden_size: int = 64,
        kernel_size: int = 11,
        num_blocks: int = 8,
        block_kernel_size: int = 3,
        dilation_base: int = 2,
        acceptor_length: int = 400,
        donor_length: int = 400,
        spline_bases: int = 10,
        spline_degree: int = 3,
        mlp_size: int = 32,
        hidden_act: str = "relu",
        batch_norm_eps: float = 1e-3,
        hidden_dropout: float = 0.5,
        num_labels: int = 56,
        head: HeadConfig | None = None,
        problem_type: str | None = "regression",
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        if pad_token_id != vocab_size:
            raise ValueError(
                f"MTSplice expects `pad_token_id` ({pad_token_id}) to equal `vocab_size` ({vocab_size}) so "
                "`N` padding is encoded as all-zero input channels."
            )
        super().__init__(num_labels=num_labels, pad_token_id=pad_token_id, **kwargs)
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.num_blocks = num_blocks
        self.block_kernel_size = block_kernel_size
        self.dilation_base = dilation_base
        self.acceptor_length = acceptor_length
        self.donor_length = donor_length
        self.spline_bases = spline_bases
        self.spline_degree = spline_degree
        self.mlp_size = mlp_size
        self.hidden_act = hidden_act
        self.batch_norm_eps = batch_norm_eps
        self.hidden_dropout = hidden_dropout
        self.problem_type = problem_type
        if head is None:
            head = HeadConfig(num_labels=num_labels, hidden_size=num_labels, problem_type=problem_type)
        elif not isinstance(head, HeadConfig):
            head = HeadConfig(**head)
        self.head = head
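The constructor requires `pad_token_id == vocab_size` so that `N` padding is encoded as all-zero input channels. A minimal pure-Python sketch of that encoding idea follows. It assumes IDs 0-3 map to A, C, G, T and a channels-first layout (`_split` slices along the last dimension, which suggests `(batch, channels, length)`); `MtSpliceEmbedding` itself is not shown in this section, and `one_hot_channels_first` is a hypothetical helper.

```python
def one_hot_channels_first(input_ids: list[int], vocab_size: int = 4) -> list[list[float]]:
    """One-hot encode token IDs into `vocab_size` channels, shaped (channels, length).

    IDs equal to `vocab_size` (the padding ID MtSpliceConfig requires) get an all-zero column.
    Hypothetical sketch; the real MtSpliceEmbedding is not shown here.
    """
    channels = [[0.0] * len(input_ids) for _ in range(vocab_size)]
    for position, token in enumerate(input_ids):
        if not 0 <= token <= vocab_size:
            raise ValueError(f"token {token} at position {position} is outside 0..{vocab_size}")
        if token < vocab_size:  # padding (token == vocab_size) leaves the column empty
            channels[token][position] = 1.0
    return channels


encoded = one_hot_channels_first([0, 3, 4, 2])  # A, T, padding/N, G under the assumed ACGT order
```

Encoding padding as zeros means padded positions contribute nothing to the first convolution beyond its bias, which is why the configuration refuses any other `pad_token_id`.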

MtSpliceForSequencePrediction

Bases: MtSplicePreTrainedModel

MTSplice with sequence-level regression loss support.

The wrapper returns the per-tissue score vector (or, when a reference and an alternative sequence are provided, the per-tissue score deltas) and applies a regression criterion when labels are supplied.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import MtSpliceConfig, MtSpliceForSequencePrediction
>>> config = MtSpliceConfig()
>>> model = MtSpliceForSequencePrediction(config)
>>> input_ids = torch.randint(4, (1, 800))
>>> alternative_input_ids = torch.randint(4, (1, 800))
>>> output = model(input_ids, alternative_input_ids=alternative_input_ids)
>>> output["logits"].shape
torch.Size([1, 56])
Source code in multimolecule/models/mtsplice/modeling_mtsplice.py
Python
class MtSpliceForSequencePrediction(MtSplicePreTrainedModel):
    """
    MTSplice with sequence-level regression loss support.

    The wrapper returns the per-tissue score vector (or, when a reference and an
    alternative sequence are provided, the per-tissue score deltas) and applies a
    regression criterion when labels are supplied.

    Examples:
        >>> import torch
        >>> from multimolecule import MtSpliceConfig, MtSpliceForSequencePrediction
        >>> config = MtSpliceConfig()
        >>> model = MtSpliceForSequencePrediction(config)
        >>> input_ids = torch.randint(4, (1, 800))
        >>> alternative_input_ids = torch.randint(4, (1, 800))
        >>> output = model(input_ids, alternative_input_ids=alternative_input_ids)
        >>> output["logits"].shape
        torch.Size([1, 56])
    """

    def __init__(self, config: MtSpliceConfig):
        super().__init__(config)
        self.model = MtSpliceModel(config)
        head = config.head
        if head is None:
            raise ValueError("MtSpliceForSequencePrediction requires `config.head` to be set")
        self.criterion = Criterion(head)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        alternative_input_ids: Tensor | NestedTensor | None = None,
        alternative_attention_mask: Tensor | None = None,
        alternative_inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            alternative_input_ids=alternative_input_ids,
            alternative_attention_mask=alternative_attention_mask,
            alternative_inputs_embeds=alternative_inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        features = outputs.delta_logits if outputs.delta_logits is not None else outputs.logits
        loss = self.criterion(features, labels) if labels is not None else None
        return SequencePredictorOutput(loss=loss, logits=features)

MtSpliceModel

Bases: MtSplicePreTrainedModel

The bare MTSplice tissue-specific backbone.

MTSplice scores a cassette exon together with its flanking introns with two parallel dilated-convolution towers (an acceptor tower over the upstream region and a donor tower over the downstream region), positionally re-weights each tower with B-spline transformations, pools, and combines the two towers into a per-tissue delta-logit-PSI vector. The backbone returns the per-tissue score vector. For variant-effect prediction, pass both a reference and an alternative sequence; the backbone then also returns the per-tissue deltas.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import MtSpliceConfig, MtSpliceModel
>>> config = MtSpliceConfig()
>>> model = MtSpliceModel(config)
>>> input_ids = torch.randint(4, (1, 800))
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 56])
Source code in multimolecule/models/mtsplice/modeling_mtsplice.py
Python
class MtSpliceModel(MtSplicePreTrainedModel):
    """
    The bare MTSplice tissue-specific backbone.

    MTSplice scores a cassette exon together with its flanking introns with two
    parallel dilated-convolution towers (an acceptor tower over the upstream
    region and a donor tower over the downstream region), positionally re-weights
    each tower with B-spline transformations, pools, and combines the two towers
    into a per-tissue delta-logit-PSI vector. The backbone returns the per-tissue
    score vector. For variant-effect prediction, pass both a reference and an
    alternative sequence; the backbone then also returns the per-tissue deltas.

    Examples:
        >>> import torch
        >>> from multimolecule import MtSpliceConfig, MtSpliceModel
        >>> config = MtSpliceConfig()
        >>> model = MtSpliceModel(config)
        >>> input_ids = torch.randint(4, (1, 800))
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 56])
    """

    def __init__(self, config: MtSpliceConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = MtSpliceEmbedding(config)
        self.acceptor_tower = MtSpliceTower(config, config.acceptor_length)
        self.donor_tower = MtSpliceTower(config, config.donor_length)
        self.pooler = MtSplicePooler()
        self.prediction = MtSplicePredictionHead(config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        alternative_input_ids: Tensor | NestedTensor | None = None,
        alternative_attention_mask: Tensor | None = None,
        alternative_inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MtSpliceModelOutput | Tuple[Tensor, ...]:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        reference = self._score(input_ids, attention_mask, inputs_embeds)

        delta = None
        alternative = None
        has_alternative = alternative_input_ids is not None or alternative_inputs_embeds is not None
        if has_alternative:
            if alternative_input_ids is not None and alternative_inputs_embeds is not None:
                raise ValueError("You cannot specify both alternative_input_ids and alternative_inputs_embeds")
            alternative = self._score(
                alternative_input_ids,
                alternative_attention_mask,
                alternative_inputs_embeds,
            )
            delta = alternative - reference

        return MtSpliceModelOutput(
            logits=reference,
            alternative_logits=alternative,
            delta_logits=delta,
        )

    def _score(
        self,
        input_ids: Tensor | NestedTensor | None,
        attention_mask: Tensor | None,
        inputs_embeds: Tensor | NestedTensor | None,
    ) -> Tensor:
        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor
        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        acceptor, donor = self._split(embedding_output)
        if self.gradient_checkpointing and self.training:
            acceptor = self._gradient_checkpointing_func(self.acceptor_tower.__call__, acceptor)
            donor = self._gradient_checkpointing_func(self.donor_tower.__call__, donor)
        else:
            acceptor = self.acceptor_tower(acceptor)
            donor = self.donor_tower(donor)
        pooled = self.pooler(acceptor, donor)
        return self.prediction(pooled)

    def _split(self, inputs_embeds: Tensor) -> tuple[Tensor, Tensor]:
        length = inputs_embeds.size(-1)
        acceptor_length = min(self.config.acceptor_length, length)
        donor_length = min(self.config.donor_length, length)
        acceptor = inputs_embeds[..., :acceptor_length]
        donor = inputs_embeds[..., length - donor_length :]
        return acceptor, donor
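`_split` takes the first `acceptor_length` positions for the acceptor tower and the last `donor_length` positions for the donor tower. The string-level mirror below (`split_windows` is a hypothetical helper, not part of the library) makes two consequences visible: an input longer than `acceptor_length + donor_length` loses its middle, and an input shorter than a window, such as the 29 bp sequence in the usage examples, is fed whole to both towers.

```python
def split_windows(sequence: str, acceptor_length: int = 400, donor_length: int = 400) -> tuple[str, str]:
    """Mirror of MtSpliceModel._split on a plain string (illustrative sketch)."""
    length = len(sequence)
    acceptor = sequence[: min(acceptor_length, length)]  # upstream window for the acceptor tower
    donor = sequence[length - min(donor_length, length) :]  # downstream window for the donor tower
    return acceptor, donor


# An 800 bp input splits cleanly into two non-overlapping 400 bp windows.
full = "A" * 400 + "C" * 400
acceptor, donor = split_windows(full)

# A 29 bp input is shorter than either window, so both towers see all of it.
short = "agcagtcattatggcgaatctggcaagta"
short_acceptor, short_donor = split_windows(short)
```

With a 500 bp input, the two windows overlap by 300 bp; with a 1000 bp input, the central 200 bp are never seen by either tower.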

MtSpliceModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the MTSplice tissue-specific model.

Parameters:

Name Type Description Default

logits

`torch.FloatTensor` of shape `(batch_size, num_labels)`

The per-tissue delta-logit-PSI score vector for the (reference) input sequence, ordered as the 56 GTEx tissues (see MtSpliceConfig).

None

alternative_logits

`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*

The per-tissue score vector for the alternative sequence, returned when an alternative sequence is provided.

None

delta_logits

`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*

alternative_logits - logits, the per-tissue variant-effect deltas, returned when an alternative sequence is provided.

None
Source code in multimolecule/models/mtsplice/modeling_mtsplice.py
Python
@dataclass
class MtSpliceModelOutput(ModelOutput):
    """
    Base class for outputs of the MTSplice tissue-specific model.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            The per-tissue delta-logit-PSI score vector for the (reference) input
            sequence, ordered as the 56 GTEx tissues (see `MtSpliceConfig`).
        alternative_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            The per-tissue score vector for the alternative sequence, returned when
            an alternative sequence is provided.
        delta_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            `alternative_logits - logits`, the per-tissue variant-effect deltas,
            returned when an alternative sequence is provided.
    """

    logits: torch.FloatTensor | None = None
    alternative_logits: torch.FloatTensor | None = None
    delta_logits: torch.FloatTensor | None = None

MtSplicePreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/mtsplice/modeling_mtsplice.py
Python
class MtSplicePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MtSpliceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["MtSpliceTower"]

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, MtSpliceSplineWeight):
            nn.init.zeros_(module.kernel)