
SPOT-RNA

Pre-trained model for RNA secondary structure prediction using two-dimensional deep neural networks and transfer learning.

Disclaimer

This is an UNOFFICIAL implementation of "RNA secondary structure prediction using an ensemble of two-dimensional deep neural networks and transfer learning" by Jaswinder Singh, et al.

The OFFICIAL repository of SPOT-RNA is at jaswindersingh2/SPOT-RNA.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released SPOT-RNA did not write a model card for it, so this model card was written by the MultiMolecule team.

Model Details

SPOT-RNA is an ensemble of 2D deep neural networks that predicts RNA secondary structure (base-pair contact maps) from a single RNA sequence. It predicts both canonical (Watson-Crick and wobble) and non-canonical base pairs, including base pairs in pseudoknots and other tertiary interactions.

The model uses:

  • pairwise representation: outer concatenation of canonical nucleotide features into an L x L x 8 feature matrix.
  • convolutional blocks: 2D residual convolution blocks with LayerNorm, dropout, and checkpoint-matched ReLU/ELU activations.
  • architecture paths: checkpoint-matched 2D-BLSTM or dilated-convolution paths where used by the released predictor.
  • training strategy: transfer learning from bpRNA to high-resolution PDB RNA structures.

MultiMolecule provides SPOT-RNA as a single checkpoint, multimolecule/spotrna.

Model Specification

Num Parameters (M) | FLOPs (G) | MACs (G)
17.46 | 8642.10 | 4302.16

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

RNA Secondary Structure Pipeline

You can use SPOT-RNA directly with the MultiMolecule secondary-structure pipeline:

Python
import multimolecule  # you must import multimolecule to register models
from transformers import pipeline

predictor = pipeline("rna-secondary-structure", model="multimolecule/spotrna")
output = predictor("GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCUCA")

PyTorch Inference

Here is how to use this model to predict RNA secondary structure in PyTorch:

Python
import torch
from multimolecule import RnaTokenizer, SpotRnaModel

tokenizer = RnaTokenizer.from_pretrained("multimolecule/spotrna")
model = SpotRnaModel.from_pretrained("multimolecule/spotrna")

sequence = "GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCUCA"
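inputs = tokenizer(sequence, return_tensors="pt")  # named `inputs` to avoid shadowing the built-in input()

with torch.no_grad():  # inference only, so gradient tracking is not needed
    output = model(**inputs)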
contact_map = output.contact_map  # (1, L, L) base-pair probability matrix
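
The probability matrix still has to be turned into discrete base pairs. The sketch below is a minimal illustration, not the library's post-processing: it applies the default `SpotRnaConfig.threshold` of 0.335 to a hand-written toy matrix and ignores conflicts such as one residue pairing with several partners. The helper name `pairs_above_threshold` and the toy values are hypothetical; with a real model the nested list could come from `output.contact_map[0].tolist()`.

```python
THRESHOLD = 0.335  # default value of SpotRnaConfig.threshold


def pairs_above_threshold(probabilities, threshold=THRESHOLD):
    """Return (i, j, p) for every upper-triangle entry with p > threshold."""
    length = len(probabilities)
    pairs = []
    for i in range(length):
        for j in range(i + 1, length):  # upper triangle only: each pair is reported once
            p = probabilities[i][j]
            if p > threshold:
                pairs.append((i, j, p))
    return pairs


# Toy symmetric 5 x 5 matrix standing in for a real contact map.
toy_contact_map = [
    [0.0, 0.1, 0.2, 0.1, 0.9],
    [0.1, 0.0, 0.1, 0.6, 0.1],
    [0.2, 0.1, 0.0, 0.1, 0.2],
    [0.1, 0.6, 0.1, 0.0, 0.1],
    [0.9, 0.1, 0.2, 0.1, 0.0],
]

for i, j, p in pairs_above_threshold(toy_contact_map):
    print(f"{i}-{j}: {p:.2f}")
```

Indices are zero-based positions in the input sequence.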

Training Details

SPOT-RNA was trained using a two-stage transfer learning approach on RNA secondary structure prediction.

Training Data

  • initial training source: bpRNA-1m (Version 1.0) with 102,348 annotated RNAs.
  • initial training filtering: CD-HIT-EST at 80% sequence identity, removal of RNAs with PDB structures, and maximum sequence length of 500 nucleotides.
  • initial training corpus: 13,419 RNAs after preprocessing.
  • initial training split: TR0 = 10,814, VL0 = 1,300, TS0 = 1,305.
  • transfer-learning source: high-resolution PDB RNAs downloaded on March 2, 2019.
  • transfer-learning filtering: resolution better than 3.5 Å and CD-HIT-EST at 80% sequence identity.
  • transfer-learning corpus: 226 nonredundant RNAs after preprocessing.
  • transfer-learning split before homology filtering: TR1 = 120, VL1 = 30, TS1 = 76.
  • additional TS1 filtering: CD-HIT-EST against the training data at 80% identity, followed by BLAST-N against TR0 and TR1 with e-value cutoff 10.
  • final TS1 benchmark: 67 RNAs.
  • additional evaluation set: TS2 = 39 NMR-solved RNAs selected from 641 candidates after CD-HIT-EST filtering at 80% identity and BLAST-N filtering against TR0, TR1, and TS1.
  • use of TS2: post-training evaluation only.
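
As a quick consistency check on the counts listed above, the split sizes can be summed and compared with the stated corpus sizes. The dictionary names are only illustrative.

```python
# Counts copied from the Training Data list.
bprna_split = {"TR0": 10_814, "VL0": 1_300, "TS0": 1_305}
pdb_split = {"TR1": 120, "VL1": 30, "TS1": 76}

bprna_total = sum(bprna_split.values())  # should match the 13,419 RNAs after preprocessing
pdb_total = sum(pdb_split.values())      # should match the 226 nonredundant PDB RNAs

# TS1 shrinks from 76 to 67 RNAs after the additional homology filtering.
ts1_removed = pdb_split["TS1"] - 67

print(bprna_total, pdb_total, ts1_removed)
```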

Training Procedure

Preprocessing

  • input representation: one-hot L x 4 matrix following the MultiMolecule tokenizer order.
  • missing-value handling: invalid or missing residues encoded as -1 in the original TensorFlow implementation before one-hot conversion.
  • pairwise features: outer concatenation from L x 4 to L x L x 8.
  • input normalization: standardization to zero mean and unit variance using training-set statistics.
  • structure labels: extracted from PDB coordinates with DSSR.
  • reference NMR model: model 1.
  • pseudoknot and motif definitions: bpRNA definitions from the paper.
  • unknown-token handling: N tokens are excluded from the canonical four-base features before pairwise feature construction.
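
The preprocessing steps above can be sketched in plain Python. This is a simplified illustration rather than the library code. It assumes the canonical channel order A, C, G, U, and it assumes that channels 0-3 describe residue i and channels 4-7 describe residue j. The mean and standard deviation values are the ones hard-coded in `SpotRnaModel._build_input_stats`. All function names are hypothetical.

```python
BASES = "ACGU"  # assumed canonical order; N or any other symbol becomes an all-zero row
MEAN = [0.223542, 0.26099518, 0.31503478, 0.18919209] * 2
STD = [0.4219779, 0.44426465, 0.46934235, 0.39735729] * 2


def one_hot(sequence):
    """L x 4 one-hot rows; non-canonical residues are encoded as all zeros."""
    return [[1.0 if base == canonical else 0.0 for canonical in BASES] for base in sequence.upper()]


def outer_concatenate(rows):
    """L x L x 8 features: cell (i, j) is row i followed by row j."""
    return [[row_i + row_j for row_j in rows] for row_i in rows]


def standardize(features):
    """Subtract the per-channel mean and divide by the per-channel standard deviation."""
    return [
        [[(value - m) / s for value, m, s in zip(cell, MEAN, STD)] for cell in row]
        for row in features
    ]


features = standardize(outer_concatenate(one_hot("GANU")))
print(len(features), len(features[0]), len(features[0][0]))
```

For a sequence of length L the result has shape L x L x 8, which is the input expected by the convolutional blocks.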

Pre-training

The paper states that training was run on Nvidia GTX TITAN X GPUs.

  • training split: TR0.
  • validation split: VL0.
  • optimizer: Adam.
  • regularization: 25% dropout before convolution layers and 50% dropout in hidden fully connected layers.
  • hyperparameter search over N_A: 16 to 32 residual blocks.
  • hyperparameter search over D_RES: 32 to 72 convolution channels.
  • hyperparameter search over D_BL: 128 to 256 2D-BLSTM hidden units per direction.
  • hyperparameter search over N_B: 0 to 4 fully connected blocks.
  • hyperparameter search over D_FC: 256 to 512 fully connected hidden units.
  • model selection: based on performance on the validation split (VL0), as described in the paper.
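
One way to read the search ranges above is to check them against the five members in the default `SpotRnaConfig.module_configs`. The sketch below does that in plain Python. The mapping from paper symbols to configuration fields is inferred from the field descriptions and should be treated as an assumption, and the dictionaries are hand-copied values, not objects from the library.

```python
# Search ranges from the list above as (minimum, maximum), both inclusive.
SEARCH_SPACE = {
    "num_conv_blocks": (16, 32),      # N_A
    "conv_channels": (32, 72),        # D_RES
    "blstm_hidden_size": (128, 256),  # D_BL
    "num_fc_blocks": (0, 4),          # N_B
    "fc_hidden_size": (256, 512),     # D_FC
}

# Values copied from the default module_configs; fields a member does not set are omitted.
RELEASED_MEMBERS = [
    {"num_conv_blocks": 16, "conv_channels": 48, "num_fc_blocks": 2, "fc_hidden_size": 512},
    {"num_conv_blocks": 20, "conv_channels": 64, "num_fc_blocks": 1, "fc_hidden_size": 512},
    {"num_conv_blocks": 30, "conv_channels": 64, "num_fc_blocks": 1, "fc_hidden_size": 512},
    {"num_conv_blocks": 30, "conv_channels": 64, "blstm_hidden_size": 200, "num_fc_blocks": 0},
    {"num_conv_blocks": 30, "conv_channels": 64, "num_fc_blocks": 1, "fc_hidden_size": 512},
]


def out_of_range(member):
    """Return the names of fields whose value lies outside the search range."""
    return [
        name
        for name, value in member.items()
        if not SEARCH_SPACE[name][0] <= value <= SEARCH_SPACE[name][1]
    ]


violations = [out_of_range(member) for member in RELEASED_MEMBERS]
print(violations)
```

An empty list for every member means the released settings all fall inside the ranges stated for the search.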

Transfer Learning

The pretrained TR0 models were retrained on TR1 with the same architecture and optimization settings.

  • initialization: start from the TR0-trained models.
  • training split: TR1.
  • validation split: VL1.
  • frozen layers: none; all weights were updated.
  • architecture and optimization settings: same as the TR0-trained models.
  • model selection: based on performance on the validation split (VL1), as described in the paper.
  • decision rule: a single probability threshold chosen to optimize validation performance.
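
The decision rule can be illustrated with a small threshold sweep. The sketch below is hypothetical: it assumes F1 as the validation metric, which the list above does not specify, and it uses made-up probabilities and labels. It shows the general technique of picking one threshold on validation data, not the authors' procedure.

```python
def f1_score(probabilities, labels, threshold):
    """F1 of the binary predictions obtained with `p > threshold`."""
    predicted = [p > threshold for p in probabilities]
    tp = sum(1 for pred, y in zip(predicted, labels) if pred and y)
    fp = sum(1 for pred, y in zip(predicted, labels) if pred and not y)
    fn = sum(1 for pred, y in zip(predicted, labels) if not pred and y)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def select_threshold(probabilities, labels, candidates):
    """Return the candidate with the highest F1; ties go to the earliest candidate."""
    return max(candidates, key=lambda t: f1_score(probabilities, labels, t))


# Made-up validation scores for seven candidate base pairs (1 = true pair).
val_probabilities = [0.05, 0.21, 0.31, 0.42, 0.57, 0.81, 0.96]
val_labels = [0, 0, 0, 1, 0, 1, 1]
candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95

best = select_threshold(val_probabilities, val_labels, candidates)
print(f"best threshold: {best:.2f}")
```

The selected value is then fixed and reused for every prediction, which is the role played by `threshold` in the model configuration.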

Citation

BibTeX
@article{singh2019rna,
  title     = "{RNA} secondary structure prediction using an ensemble of two-dimensional deep neural networks and transfer learning",
  author    = "Singh, Jaswinder and Hanson, Jack and Paliwal, Kuldip and Zhou, Yaoqi",
  journal   = "Nature Communications",
  doi       = "10.1038/s41467-019-13395-9",
  publisher = "Springer Science and Business Media LLC",
  url       = "https://doi.org/10.1038/s41467-019-13395-9",
  volume    =  10,
  number    =  1,
  pages     = "5407",
  month     =  nov,
  year      =  2019,
  copyright = "https://creativecommons.org/licenses/by/4.0",
  language  = "en"
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the SPOT-RNA paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.spotrna

RnaTokenizer

Bases: Tokenizer

Tokenizer for RNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): alphabet to use for tokenization.
    • If it is None, the standard RNA alphabet will be used.
    • If it is a string, it should correspond to the name of a predefined alphabet: standard, extended, streamline, or nucleobase.
    • If it is an alphabet or a list of characters, that specific alphabet will be used.
  • nmers (int, default: 1): size of the k-mer to tokenize.
  • codon (bool, default: False): whether to tokenize into codons.
  • replace_T_with_U (bool, default: True): whether to replace T with U.
  • do_upper_case (bool, default: True): whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import RnaTokenizer
>>> tokenizer = RnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHVIX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = RnaTokenizer(replace_T_with_U=False)
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = RnaTokenizer(nmers=3)
>>> tokenizer('uagcuuauc')["input_ids"]
[1, 83, 17, 64, 49, 96, 84, 22, 2]
>>> tokenizer = RnaTokenizer(codon=True)
>>> tokenizer('uagcuuauc')["input_ids"]
[1, 83, 49, 22, 2]
>>> tokenizer('uagcuuauca')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/rna/tokenization_rna.py
Python
class RnaTokenizer(Tokenizer):
    """
    Tokenizer for RNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If is `None`, the standard RNA alphabet will be used.
            - If is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `extended`
                + `streamline`
                + `nucleobase`
            - If is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_T_with_U: Whether to replace T with U.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import RnaTokenizer
        >>> tokenizer = RnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHVIX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = RnaTokenizer(replace_T_with_U=False)
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = RnaTokenizer(nmers=3)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 17, 64, 49, 96, 84, 22, 2]
        >>> tokenizer = RnaTokenizer(codon=True)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 49, 22, 2]
        >>> tokenizer('uagcuuauca')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_T_with_U: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_T_with_U=replace_T_with_U,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_T_with_U = replace_T_with_U
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_T_with_U:
            text = text.replace("T", "U")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
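
The difference between `nmers` and `codon` is easiest to see on the token strings themselves. The function below is a standalone restatement of the splitting rules in `_tokenize`, written without the tokenizer class so it runs on its own. The name `split_tokens` is hypothetical, and the vocabulary lookup that produces `input_ids` is not reproduced.

```python
def split_tokens(text, nmers=1, codon=False):
    """Split a sequence the way RnaTokenizer._tokenize does, with default case and T-to-U handling."""
    text = text.upper().replace("T", "U")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(
                f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
            )
        # Non-overlapping windows of three residues.
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # Overlapping windows that advance one residue at a time.
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(split_tokens("uagcuuauc", nmers=3))
print(split_tokens("uagcuuauc", codon=True))
```

A nine-residue input yields seven overlapping 3-mers but only three codons, which matches the number of non-special ids in the doctest outputs above.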

SpotRnaConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a SpotRnaModel. It is used to instantiate a SPOT-RNA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SPOT-RNA jaswindersingh2/SPOT-RNA architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default: 5): token vocabulary size of the SPOT-RNA model. Defaults to 5 for the A/C/G/U/N tokenizer vocabulary.
  • module_configs (list[SpotRnaModuleConfig] | None, default: None): list of internal architecture configurations. Each entry is a SpotRnaModuleConfig object. If None, defaults to the released SPOT-RNA architecture.
  • input_channels (int, default: 8): number of input feature channels after outer concatenation. Defaults to 8 for the canonical four-base pairwise representation.
  • hidden_act (str, default: 'relu'): the non-linear activation function in the convolutional and fully connected blocks.
  • conv_dropout (float, default: 0.25): dropout rate in the convolutional blocks.
  • fc_dropout (float, default: 0.5): dropout rate in the fully connected blocks.
  • threshold (float, default: 0.335): probability threshold for predicting base pairs during post-processing.

Examples:

Python Console Session
>>> from multimolecule import SpotRnaConfig, SpotRnaModel
>>> configuration = SpotRnaConfig()
>>> model = SpotRnaModel(configuration)
>>> configuration = model.config
Source code in multimolecule/models/spotrna/configuration_spotrna.py
Python
class SpotRnaConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`SpotRnaModel`][multimolecule.models.SpotRnaModel]. It is used to instantiate a SPOT-RNA model according to
    the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the SPOT-RNA
    [jaswindersingh2/SPOT-RNA](https://github.com/jaswindersingh2/SPOT-RNA) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Token vocabulary size of the SPOT-RNA model. Defaults to 5 for the `A/C/G/U/N` tokenizer vocabulary.
        module_configs:
            List of internal architecture configurations. Each entry is a [`SpotRnaModuleConfig`] object. If None,
            defaults to the released SPOT-RNA architecture.
        input_channels:
            Number of input feature channels after outer concatenation. Defaults to 8 for the canonical four-base
            pairwise representation.
        hidden_act:
            The non-linear activation function in the convolutional and fully connected blocks.
        conv_dropout:
            Dropout rate in the convolutional blocks.
        fc_dropout:
            Dropout rate in the fully connected blocks.
        threshold:
            Probability threshold for predicting base pairs during post-processing.

    Examples:
        >>> from multimolecule import SpotRnaConfig, SpotRnaModel
        >>> configuration = SpotRnaConfig()
        >>> model = SpotRnaModel(configuration)
        >>> configuration = model.config
    """

    model_type = "spotrna"

    def __init__(
        self,
        vocab_size: int = 5,
        module_configs: list[SpotRnaModuleConfig] | None = None,
        input_channels: int = 8,
        hidden_act: str = "relu",
        conv_dropout: float = 0.25,
        fc_dropout: float = 0.5,
        threshold: float = 0.335,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if input_channels % 2 != 0:
            raise ValueError(f"SpotRnaConfig.input_channels must be even, but got {input_channels}.")
        self.vocab_size = vocab_size
        if module_configs is None:
            module_configs = [
                SpotRnaModuleConfig(num_conv_blocks=16, conv_channels=48, num_fc_blocks=2, fc_hidden_size=512),
                SpotRnaModuleConfig(num_conv_blocks=20, conv_channels=64, num_fc_blocks=1, fc_hidden_size=512),
                SpotRnaModuleConfig(num_conv_blocks=30, conv_channels=64, num_fc_blocks=1, fc_hidden_size=512),
                SpotRnaModuleConfig(
                    num_conv_blocks=30,
                    conv_channels=64,
                    num_blstm_blocks=1,
                    blstm_hidden_size=200,
                    num_fc_blocks=0,
                    hidden_act="elu",
                ),
                SpotRnaModuleConfig(
                    num_conv_blocks=30,
                    conv_channels=64,
                    num_fc_blocks=1,
                    fc_hidden_size=512,
                    fc_act="elu",
                    output_act="elu",
                    use_dilation=True,
                    dilation_cycle=5,
                ),
            ]
        self.module_configs = module_configs
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.conv_dropout = conv_dropout
        self.fc_dropout = fc_dropout
        self.threshold = threshold

SpotRnaModuleConfig

Bases: FlatDict

Configuration for one internal SPOT-RNA architecture member.

Parameters:

  • num_conv_blocks (int, default: 20): number of convolutional blocks (N_A in the paper).
  • num_blstm_blocks (int, default: 0): number of 2D bidirectional LSTM blocks. Set to 0 to disable.
  • num_fc_blocks (int, default: 1): number of fully connected blocks. Set to 0 to disable.
  • conv_channels (int, default: 64): number of channels in the convolutional blocks.
  • blstm_hidden_size (int, default: 200): hidden size per direction in the 2D-BLSTM. Ignored if num_blstm_blocks is 0.
  • fc_hidden_size (int, default: 512): hidden size of the fully connected blocks. Ignored if num_fc_blocks is 0.
  • hidden_act (str, default: 'relu'): activation used in the convolutional residual blocks.
  • fc_act (str | None, default: None): optional activation used in the fully connected blocks. Falls back to hidden_act when unset.
  • output_act (str, default: 'relu'): activation applied before the final normalization stage.
  • use_dilation (bool, default: False): whether to use dilated convolutions.
  • dilation_cycle (int, default: 5): the cycle length for the dilation factor.
Source code in multimolecule/models/spotrna/configuration_spotrna.py
Python
class SpotRnaModuleConfig(FlatDict):
    r"""
    Configuration for one internal SPOT-RNA architecture member.

    Args:
        num_conv_blocks:
            Number of convolutional blocks (N_A in the paper).
        num_blstm_blocks:
            Number of 2D bidirectional LSTM blocks. Set to 0 to disable.
        num_fc_blocks:
            Number of fully connected blocks. Set to 0 to disable.
        conv_channels:
            Number of channels in the convolutional blocks.
        blstm_hidden_size:
            Hidden size per direction in the 2D-BLSTM. Ignored if num_blstm_blocks is 0.
        fc_hidden_size:
            Hidden size of the fully connected blocks. Ignored if num_fc_blocks is 0.
        hidden_act:
            Activation used in the convolutional residual blocks.
        fc_act:
            Optional activation used in the fully connected blocks. Falls back to `hidden_act` when unset.
        output_act:
            Activation applied before the final normalization stage.
        use_dilation:
            Whether to use dilated convolutions.
        dilation_cycle:
            The cycle length for the dilation factor.
    """

    num_conv_blocks: int = 20
    num_blstm_blocks: int = 0
    num_fc_blocks: int = 1
    conv_channels: int = 64
    blstm_hidden_size: int = 200
    fc_hidden_size: int = 512
    hidden_act: str = "relu"
    fc_act: str | None = None
    output_act: str = "relu"
    use_dilation: bool = False
    dilation_cycle: int = 5
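
The `fc_act` fallback described in the docstring can be sketched with a small stand-in dataclass. `MemberActivations` and `resolved_fc_act` are hypothetical names used only for illustration; the real class is a `FlatDict`, and the code that consumes these fields is not shown in this section.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MemberActivations:
    """Hypothetical stand-in holding only the activation fields of SpotRnaModuleConfig."""

    hidden_act: str = "relu"
    fc_act: Optional[str] = None
    output_act: str = "relu"

    def resolved_fc_act(self) -> str:
        # fc_act falls back to hidden_act when it is left unset.
        return self.hidden_act if self.fc_act is None else self.fc_act


default_member = MemberActivations()
blstm_member = MemberActivations(hidden_act="elu")  # activation settings of the fourth default member
dilated_member = MemberActivations(fc_act="elu", output_act="elu")  # settings of the fifth default member

for name, member in [("default", default_member), ("blstm", blstm_member), ("dilated", dilated_member)]:
    print(name, member.hidden_act, member.resolved_fc_act(), member.output_act)
```

In the default configuration the fourth member sets `num_fc_blocks=0`, so its resolved fully connected activation would go unused. The fifth member keeps ReLU in its residual blocks while switching the fully connected and output activations to ELU.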

SpotRnaModel

Bases: SpotRnaPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import SpotRnaConfig, SpotRnaModel
>>> config = SpotRnaConfig()
>>> model = SpotRnaModel(config)
>>> input_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1]])
>>> output = model(input_ids=input_ids)
>>> output["logits"].shape
torch.Size([1, 10, 10])
>>> output["contact_map"].shape
torch.Size([1, 10, 10])
Source code in multimolecule/models/spotrna/modeling_spotrna.py
Python
class SpotRnaModel(SpotRnaPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import SpotRnaConfig, SpotRnaModel
        >>> config = SpotRnaConfig()
        >>> model = SpotRnaModel(config)
        >>> input_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1]])
        >>> output = model(input_ids=input_ids)
        >>> output["logits"].shape
        torch.Size([1, 10, 10])
        >>> output["contact_map"].shape
        torch.Size([1, 10, 10])
    """

    def __init__(self, config: SpotRnaConfig):
        super().__init__(config)
        self.gradient_checkpointing = False

        mean, std = self._build_input_stats(device=None)
        self.register_buffer("input_mean", mean, persistent=False)
        self.register_buffer("input_std", std, persistent=False)
        self._initialized = False

        self.members = nn.ModuleList(
            [
                SpotRnaModule(
                    config,
                    SpotRnaModuleConfig(**module_config) if isinstance(module_config, dict) else module_config,
                )
                for module_config in config.module_configs
            ]
        )
        self.criterion = nn.BCEWithLogitsLoss()

        self.post_init()

    @staticmethod
    def _build_input_stats(device: torch.device | None) -> tuple[Tensor, Tensor]:
        mean = torch.tensor(
            [0.223542, 0.26099518, 0.31503478, 0.18919209, 0.223542, 0.26099518, 0.31503478, 0.18919209],
            device=device,
        ).reshape(1, 1, 1, 8)
        std = torch.tensor(
            [0.4219779, 0.44426465, 0.46934235, 0.39735729, 0.4219779, 0.44426465, 0.46934235, 0.39735729],
            device=device,
        ).reshape(1, 1, 1, 8)
        return mean, std

    def postprocess(self, outputs, input_ids=None, **kwargs):
        return outputs["contact_map"]

    def _prepare_inputs_embeds(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
    ) -> Tensor:
        num_bases = self.config.input_channels // 2
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            canonical_ids = input_ids.clamp(min=0, max=num_bases - 1)
            inputs_embeds = F.one_hot(canonical_ids, num_classes=num_bases).to(dtype=self.input_mean.dtype)
            valid_tokens = (input_ids >= 0) & (input_ids < num_bases)
            inputs_embeds = inputs_embeds * valid_tokens.unsqueeze(-1)
        else:
            if inputs_embeds.size(-1) < num_bases:
                raise ValueError(
                    f"inputs_embeds last dimension ({inputs_embeds.size(-1)}) must be at least {num_bases}."
                )
            inputs_embeds = inputs_embeds[..., :num_bases].to(dtype=self.input_mean.dtype)

        if attention_mask is not None:
            inputs_embeds = inputs_embeds * attention_mask.unsqueeze(-1).to(inputs_embeds.dtype)
        return inputs_embeds

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SpotRnaModelOutput:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is not None and isinstance(inputs_embeds, NestedTensor):
            raise TypeError("SpotRnaModel does not support NestedTensor inputs_embeds")
        inputs_embeds = self._prepare_inputs_embeds(input_ids, attention_mask, inputs_embeds)

        hidden_state = _outer_concatenate(inputs_embeds)
        if not self._initialized:
            # Workaround for transformers v5 meta-init: non-persistent buffers stay on the
            # meta device after `from_pretrained`, so re-register on the input device once
            # the real device is known.
            mean, std = self._build_input_stats(device=hidden_state.device)
            self.register_buffer("input_mean", mean, persistent=False)
            self.register_buffer("input_std", std, persistent=False)
            self._initialized = True
        hidden_state = (hidden_state - self.input_mean.to(hidden_state.dtype)) / self.input_std.to(hidden_state.dtype)

        member_logits = [member(hidden_state) for member in self.members]
        contact_map = torch.stack([torch.sigmoid(logits) for logits in member_logits]).mean(dim=0)
        logits = torch.stack(member_logits).mean(dim=0)

        loss = None
        if labels is not None:
            sequence_length = logits.size(1)
            upper_triangle_mask = torch.triu(
                torch.ones(sequence_length, sequence_length, device=logits.device, dtype=torch.bool), diagonal=2
            )
            loss = self.criterion(logits[:, upper_triangle_mask], labels[:, upper_triangle_mask].to(dtype=logits.dtype))

        return SpotRnaModelOutput(
            loss=loss,
            logits=logits,
            contact_map=contact_map,
        )
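
Two details of `forward` are worth seeing with numbers. First, `contact_map` is the mean of the per-member sigmoids, while `logits` is the mean of the per-member logits, so applying a sigmoid to `logits` does not in general reproduce `contact_map`. Second, the loss only covers pairs selected by `torch.triu(..., diagonal=2)`, that is pairs with j - i >= 2. The sketch below uses plain Python and made-up logits for a single (i, j) entry; the helper names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Made-up logits from three hypothetical ensemble members for one (i, j) entry.
member_logits = [-2.0, 0.5, 3.0]

mean_logit = sum(member_logits) / len(member_logits)
mean_probability = sum(sigmoid(x) for x in member_logits) / len(member_logits)

print(f"sigmoid(mean logit) = {sigmoid(mean_logit):.4f}")
print(f"mean of sigmoids    = {mean_probability:.4f}")


def scored_pairs(sequence_length, min_separation=2):
    """Index pairs kept by an upper-triangle mask with diagonal=min_separation."""
    return [
        (i, j)
        for i in range(sequence_length)
        for j in range(i + min_separation, sequence_length)
    ]


print(len(scored_pairs(10)))
```

Self-pairs and directly adjacent residues are excluded from the loss, and each remaining pair is counted once rather than twice.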

SpotRnaModelOutput dataclass

Bases: ModelOutput

Output type for SPOT-RNA model.

Parameters:

  • loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided; default: None): binary cross-entropy loss for base-pair prediction.
  • logits (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`; default: None): prediction logits before sigmoid.
  • contact_map (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`, *optional*; default: None): base-pair probability matrix (after sigmoid).
Source code in multimolecule/models/spotrna/modeling_spotrna.py
Python
@dataclass
class SpotRnaModelOutput(ModelOutput):
    """
    Output type for SPOT-RNA model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Binary cross-entropy loss for base-pair prediction.
        logits (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`):
            Prediction logits before sigmoid.
        contact_map (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`, *optional*):
            Base-pair probability matrix (after sigmoid).
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contact_map: torch.FloatTensor | None = None