
UFold

Pre-trained model for RNA secondary structure prediction using an image-like sequence representation and a U-Net.

Disclaimer

This is an UNOFFICIAL implementation of UFold: fast and accurate RNA secondary structure prediction with deep learning by Laiyi Fu, Yingxin Cao, et al.

The OFFICIAL repository of UFold is at uci-cbcl/UFold.

Tip

The MultiMolecule implementation is a direct PyTorch port of the original U-Net architecture and feature construction.

The team releasing UFold did not write this model card, so it has been written by the MultiMolecule team.

Model Details

UFold predicts RNA base-pair contact maps from single RNA sequences. It represents a sequence as a 17-channel image: 16 channels are outer products of one-hot nucleotide indicators and one channel is a hand-crafted canonical/wobble pairing score. A U-Net predicts a symmetric contact score matrix, and the original constrained post-processing routine can be enabled to enforce base-pairing constraints.
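The 17-channel input can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the MultiMolecule `ufold_features` implementation (its body is not shown on this page). It assumes the A/C/G/U channel order, reuses the canonical/wobble scores from the `pair_score` buffer in the `UfoldModel` source further down (A-U 2, C-G 3, G-U 0.8), and models the pairing-score channel on the loop in the original UFold `creatmat` routine: it sums Gaussian-weighted scores exp(-0.5 k²) over up to 30 consecutive stacked pairs on each side of a candidate pair (i, j). All function names here are hypothetical.

```python
import math
from itertools import product

ALPHABET = "ACGU"  # assumed channel order (matches the rows of the pair_score buffer)
# Canonical and wobble pairing scores, as in the pair_score buffer: A-U 2, C-G 3, G-U 0.8.
PAIR_SCORE = {
    ("A", "U"): 2.0, ("U", "A"): 2.0,
    ("C", "G"): 3.0, ("G", "C"): 3.0,
    ("G", "U"): 0.8, ("U", "G"): 0.8,
}
MAX_OFFSET = 30  # length of the gaussian_weights buffer


def outer_product_channels(sequence: str) -> list[list[list[float]]]:
    """16 L x L one-hot outer products; channel (a, b) is 1 at (i, j) iff seq[i] == a and seq[j] == b."""
    return [
        [[1.0 if (x == a and y == b) else 0.0 for y in sequence] for x in sequence]
        for a, b in product(ALPHABET, repeat=2)
    ]


def pairing_score_channel(sequence: str) -> list[list[float]]:
    """Gaussian-weighted stacking score, modeled on the original UFold creatmat loop."""
    n = len(sequence)

    def score(i: int, j: int) -> float:
        return PAIR_SCORE.get((sequence[i], sequence[j]), 0.0)

    mat = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            total = 0.0
            # Walk along (i - k, j + k) while consecutive pairs can still form.
            for k in range(MAX_OFFSET):
                if i - k < 0 or j + k >= n or score(i - k, j + k) == 0.0:
                    break
                total += score(i - k, j + k) * math.exp(-0.5 * k * k)
            if total > 0.0:
                # Then walk along (i + k, j - k) with the same Gaussian decay.
                for k in range(1, MAX_OFFSET):
                    if i + k >= n or j - k < 0 or score(i + k, j - k) == 0.0:
                        break
                    total += score(i + k, j - k) * math.exp(-0.5 * k * k)
            mat[i][j] = total
    return mat


def ufold_like_features(sequence: str) -> list[list[list[float]]]:
    """Stack 16 outer-product channels and 1 pairing-score channel into 17 L x L channels."""
    return outer_product_channels(sequence) + [pairing_score_channel(sequence)]


features = ufold_like_features("GGGAAACCC")
print(len(features), len(features[0]), len(features[0][0]))  # 17 9 9
```

The pairing-score channel rewards positions that sit inside a run of stackable pairs, which gives the U-Net a helix prior that plain one-hot outer products do not encode.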

Model Specification

| Num Parameters (M) | FLOPs (G) | MACs (G) |
| --- | --- | --- |
| 8.64 | 188.29 | 93.81 |

FLOPs and MACs are computed with multimolecule.utils for one 600 nt sequence.
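The U-Net pads its square input before the convolutions (see `min_size` and `size_multiple` in the configuration reference). The `_get_padded_length` helper called in the source is not shown, so the rule below is an assumption: round the sequence length up to a multiple of `size_multiple` and never go below `min_size`. Under that assumption, a 600 nt sequence would be padded to a 608 x 608 image, and a 74 nt sequence such as the usage example would be padded to 80 x 80. `padded_length` is a hypothetical name.

```python
def padded_length(length: int, min_size: int = 80, size_multiple: int = 16) -> int:
    """Hypothetical padding rule: round up to a multiple of size_multiple, never below min_size."""
    rounded_up = -(-length // size_multiple) * size_multiple  # ceiling division
    return max(min_size, rounded_up)


# 16 = 2 ** 4, so four 2x downsampling stages divide the padded size evenly.
print(padded_length(600))  # 608
print(padded_length(74))   # 80
```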

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

RNA Secondary Structure Pipeline

Python
import multimolecule
from transformers import pipeline

predictor = pipeline("rna-secondary-structure", model="multimolecule/ufold")
output = predictor("GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCUCA")

PyTorch Inference

Python
from multimolecule import RnaTokenizer, UfoldModel

tokenizer = RnaTokenizer.from_pretrained("multimolecule/ufold")
model = UfoldModel.from_pretrained("multimolecule/ufold")

sequence = "GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCUCA"
inputs = tokenizer(sequence, return_tensors="pt")
output = model(**inputs)

contact_map = output.contact_map
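`contact_map` holds per-pair probabilities (or, with post-processing, a binary map). The sketch below shows one way to turn such a matrix into base pairs and a dot-bracket string. `contact_map_to_pairs` and `pairs_to_dot_bracket` are hypothetical helpers, not MultiMolecule APIs, and they take plain nested lists (for a batched tensor, `output.contact_map[0].tolist()` gives the first sequence in that form). Pairs above 0.5 are kept greedily from the highest probability down, so each base pairs at most once, and crossing pairs (pseudoknots) fall back to extra bracket types.

```python
def contact_map_to_pairs(contact_map: list[list[float]], threshold: float = 0.5) -> list[tuple[int, int]]:
    """Greedily select (i, j) pairs with i < j above threshold; each base is used at most once."""
    n = len(contact_map)
    candidates = sorted(
        ((contact_map[i][j], i, j) for i in range(n) for j in range(i + 1, n) if contact_map[i][j] > threshold),
        reverse=True,  # highest probability first
    )
    used: set[int] = set()
    pairs = []
    for _, i, j in candidates:
        if i not in used and j not in used:
            pairs.append((i, j))
            used.update((i, j))
    return sorted(pairs)


def pairs_to_dot_bracket(pairs: list[tuple[int, int]], length: int) -> str:
    """Write pairs as dot-bracket, moving crossing pairs to the next bracket type."""
    brackets = ["()", "[]", "{}", "<>"]
    levels: list[list[tuple[int, int]]] = [[] for _ in brackets]
    structure = ["."] * length
    for i, j in sorted(pairs):
        for level, assigned in enumerate(levels):
            # Pairs are visited by increasing i, so (k, l) crosses (i, j) when k < i < l < j.
            if not any(k < i < l < j for k, l in assigned):
                assigned.append((i, j))
                structure[i], structure[j] = brackets[level]
                break
        else:
            raise ValueError("too many crossing pairs for the available bracket types")
    return "".join(structure)


probs = [[0.0] * 6 for _ in range(6)]
probs[0][5] = probs[5][0] = 0.9
probs[1][4] = probs[4][1] = 0.8
probs[1][5] = probs[5][1] = 0.6  # conflicts with both pairs above, so it is dropped
pairs = contact_map_to_pairs(probs)
print(pairs)                            # [(0, 5), (1, 4)]
print(pairs_to_dot_bracket(pairs, 6))   # ((..))
```

A greedy pass is the simplest choice here; UFold's own post-processing solves a constrained optimization instead, so the two can disagree when candidate pairs conflict.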

To run the original constrained post-processing loop:

Python
output = model(**inputs, use_postprocessing=True)
contact_map = output.postprocessed_contact_map

Training Details

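UFold was trained to predict base-pair contact maps from annotated RNA secondary structures; canonical and wobble base-pairing rules enter through the hand-crafted input channel and the constrained post-processing step.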

Training Data

  • RNAStrAlign: 30,451 unique RNAs from eight RNA families; the paper reports a random split with 24,895 training RNAs and 2,854 test RNAs after redundancy filtering.
  • bpRNA-1m: 102,318 RNAs from 2,588 families; CD-HIT was used to remove redundant sequences before splitting the data into TR0 and TS0.
  • augmented data: synthetic training examples were generated from bpRNA-new sequences by random mutation and structure prediction.
  • PDB training data: high-resolution RNA structures from bpRNA and the PDB were used for fine-tuning/evaluation experiments; test sets TS1, TS2, and TS3 were filtered at 80% sequence identity.
  • evaluation data: ArchiveII, TS0, bpRNA-new, and PDB test data were used for benchmark evaluation.

Training Procedure

  • input representation: 16 outer-product channels following the MultiMolecule tokenizer order plus one hand-crafted pairing-score channel.
  • objective: weighted binary cross entropy over base-pair contact maps.
  • optimizer: Adam.
  • training epochs: 100.
  • batch size: 1.
  • positive-class weight: 300.
  • post-processing: constrained optimization with canonical/wobble pairing rules, sparsity shrinkage, and a 0.5 threshold.
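The objective and positive-class weight above amount to binary cross-entropy on logits with paired entries up-weighted. For a logit x, label y in {0, 1} and weight w, the per-entry loss is -[w·y·log σ(x) + (1 - y)·log(1 - σ(x))], averaged over valid entries; this is the formula PyTorch documents for `torch.nn.functional.binary_cross_entropy_with_logits` with `pos_weight`, which the `forward` source below calls. The stdlib-only sketch reproduces that formula for illustration; `weighted_bce_with_logits` is a hypothetical name.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def weighted_bce_with_logits(logits: list[float], labels: list[float], pos_weight: float = 300.0) -> float:
    """Mean of -[w * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    total = 0.0
    for x, y in zip(logits, labels):
        log_p = -softplus(-x)     # log(sigmoid(x))
        log_not_p = -softplus(x)  # log(1 - sigmoid(x))
        total += -(pos_weight * y * log_p + (1.0 - y) * log_not_p)
    return total / len(logits)


# At logit 0, the single paired entry contributes 300 times as much as each unpaired one.
print(weighted_bce_with_logits([0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]))
```

Contact maps are very sparse (most entries are unpaired), which is why such a large positive weight is used.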

Citation

BibTeX
@article{fu2022ufold,
  author = {Fu, Laiyi and Cao, Yingxin and Wu, Jie and Peng, Qinke and Nie, Qing and Xie, Xiaohui},
  title = {UFold: fast and accurate RNA secondary structure prediction with deep learning},
  journal = {Nucleic Acids Research},
  volume = {50},
  number = {3},
  pages = {e14},
  year = {2022},
  doi = {10.1093/nar/gkab1074}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If MultiMolecule supports your research, please cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the UFold paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

UfoldConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a UfoldModel. It is used to instantiate a UFold model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the UFold uci-cbcl/UFold architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Token vocabulary size of the UFold model. Defaults to 5 for the A/C/G/U/N tokenizer vocabulary. | 5 |
| input_channels | int | Number of image-like input channels. The original UFold model uses 16 outer-product base-pair channels plus one hand-crafted pairing-score channel; any other value raises a ValueError. | 17 |
| output_channels | int | Number of U-Net output channels. The original UFold model predicts one contact-score matrix; any other value raises a ValueError. | 1 |
| channel_sizes | list[int] \| None | U-Net channel sizes for the five original resolution levels. None uses [32, 64, 128, 256, 512]. | None |
| min_size | int | Minimum padded image size used before the U-Net. The original short-sequence dataset pads to 80. | 80 |
| size_multiple | int | Spatial size multiple required by the four downsampling stages. | 16 |
| batch_norm_eps | float | Epsilon used by the BatchNorm2d layers. | 1e-05 |
| batch_norm_momentum | float | Momentum used by the BatchNorm2d layers. | 0.1 |
| threshold | float | Probability threshold for predicting base pairs during post-processing. | 0.5 |
| use_postprocessing | bool | Whether to run the UFold constrained post-processing loop in forward. | False |
| postprocess_iterations | int | Number of constrained post-processing iterations. | 100 |
| postprocess_lr_min | float | Learning rate for the minimization step in UFold post-processing. | 0.01 |
| postprocess_lr_max | float | Learning rate for the Lagrangian multiplier maximization step in UFold post-processing. | 0.1 |
| postprocess_rho | float | L1 sparsity coefficient used by UFold post-processing. | 1.6 |
| postprocess_with_l1 | bool | Whether to apply L1 shrinkage in UFold post-processing. | True |
| postprocess_s | float | Logit cutoff used by UFold post-processing. Defaults to log(9), the original value. | 2.1972245773362196 |
| allow_noncanonical | bool | Whether post-processing should allow non-canonical base pairs. | False |
| pos_weight | float | Positive-class weight used by the original weighted binary cross-entropy training loss. | 300.0 |

Examples:

Python Console Session
>>> from multimolecule import UfoldConfig, UfoldModel
>>> configuration = UfoldConfig()
>>> model = UfoldModel(configuration)
>>> configuration = model.config
Source code in multimolecule/models/ufold/configuration_ufold.py
Python
class UfoldConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`UfoldModel`][multimolecule.models.UfoldModel]. It is used to instantiate a UFold model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the UFold [uci-cbcl/UFold](https://github.com/uci-cbcl/UFold) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Token vocabulary size of the UFold model. Defaults to 5 for the `A/C/G/U/N` tokenizer vocabulary.
        input_channels:
            Number of image-like input channels. The original UFold model uses 16 outer-product base-pair channels
            plus one hand-crafted pairing-score channel.
        output_channels:
            Number of U-Net output channels. The original UFold model predicts one contact-score matrix.
        channel_sizes:
            U-Net channel sizes for the five original resolution levels.
        min_size:
            Minimum padded image size used before the U-Net. The original short-sequence dataset pads to 80.
        size_multiple:
            Spatial size multiple required by the four downsampling stages.
        batch_norm_eps:
            Epsilon used by the BatchNorm2d layers.
        batch_norm_momentum:
            Momentum used by the BatchNorm2d layers.
        threshold:
            Probability threshold for predicting base pairs during post-processing.
        use_postprocessing:
            Whether to run the UFold constrained post-processing loop in `forward`.
        postprocess_iterations:
            Number of constrained post-processing iterations.
        postprocess_lr_min:
            Learning rate for the minimization step in UFold post-processing.
        postprocess_lr_max:
            Learning rate for the Lagrangian multiplier maximization step in UFold post-processing.
        postprocess_rho:
            L1 sparsity coefficient used by UFold post-processing.
        postprocess_with_l1:
            Whether to apply L1 shrinkage in UFold post-processing.
        postprocess_s:
            Logit cutoff used by UFold post-processing. Defaults to `log(9)`, the original value.
        allow_noncanonical:
            Whether post-processing should allow non-canonical base pairs.
        pos_weight:
            Positive-class weight used by the original weighted binary cross-entropy training loss.

    Examples:
        >>> from multimolecule import UfoldConfig, UfoldModel
        >>> configuration = UfoldConfig()
        >>> model = UfoldModel(configuration)
        >>> configuration = model.config
    """

    model_type = "ufold"

    def __init__(
        self,
        vocab_size: int = 5,
        input_channels: int = 17,
        output_channels: int = 1,
        channel_sizes: list[int] | None = None,
        min_size: int = 80,
        size_multiple: int = 16,
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        threshold: float = 0.5,
        use_postprocessing: bool = False,
        postprocess_iterations: int = 100,
        postprocess_lr_min: float = 0.01,
        postprocess_lr_max: float = 0.1,
        postprocess_rho: float = 1.6,
        postprocess_with_l1: bool = True,
        postprocess_s: float = 2.1972245773362196,
        allow_noncanonical: bool = False,
        pos_weight: float = 300.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if input_channels != 17:
            raise ValueError(f"UFold expects 17 input channels, but got {input_channels}.")
        if output_channels != 1:
            raise ValueError(f"UFold expects one output channel, but got {output_channels}.")
        if channel_sizes is None:
            channel_sizes = [32, 64, 128, 256, 512]
        if len(channel_sizes) != 5:
            raise ValueError(f"UFold expects five channel sizes, but got {len(channel_sizes)}.")
        if min_size <= 0:
            raise ValueError(f"min_size must be positive, but got {min_size}.")
        if size_multiple <= 0:
            raise ValueError(f"size_multiple must be positive, but got {size_multiple}.")

        self.vocab_size = vocab_size
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.channel_sizes = channel_sizes
        self.min_size = min_size
        self.size_multiple = size_multiple
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.threshold = threshold
        self.use_postprocessing = use_postprocessing
        self.postprocess_iterations = postprocess_iterations
        self.postprocess_lr_min = postprocess_lr_min
        self.postprocess_lr_max = postprocess_lr_max
        self.postprocess_rho = postprocess_rho
        self.postprocess_with_l1 = postprocess_with_l1
        self.postprocess_s = postprocess_s
        self.allow_noncanonical = allow_noncanonical
        self.pos_weight = pos_weight
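The `postprocess_s` default is the natural logarithm of 9, and the `_postprocess` source below uses it as a logit offset inside `torch.sigmoid(2 * (logits - s))`. A quick check, in plain `math` and for illustration only, of what that cutoff means in probability space:

```python
import math

s = math.log(9)  # the postprocess_s default


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# A logit of log(9) corresponds to probability 9 / (9 + 1) = 0.9.
prob_at_cutoff = sigmoid(s)

# The soft gate sigmoid(2 * (x - s)) from _postprocess is exactly 0.5 at x = s
# and close to 1 for logits a few units above the cutoff.
gate_at_cutoff = sigmoid(2 * (s - s))
gate_above = sigmoid(2 * (s + 3 - s))
print(s, prob_at_cutoff, gate_at_cutoff, gate_above)
```

In other words, only entries the network already scores above roughly 90% probability pass the soft gate with substantial weight at the start of post-processing.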

UfoldModel

Bases: UfoldPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import UfoldConfig, UfoldModel
>>> config = UfoldConfig(postprocess_iterations=1)
>>> model = UfoldModel(config)
>>> input_ids = torch.tensor([[0, 3, 2, 1, 0, 3]])
>>> output = model(input_ids=input_ids)
>>> output["logits"].shape
torch.Size([1, 6, 6])
>>> output["contact_map"].shape
torch.Size([1, 6, 6])
Source code in multimolecule/models/ufold/modeling_ufold.py
Python
class UfoldModel(UfoldPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import UfoldConfig, UfoldModel
        >>> config = UfoldConfig(postprocess_iterations=1)
        >>> model = UfoldModel(config)
        >>> input_ids = torch.tensor([[0, 3, 2, 1, 0, 3]])
        >>> output = model(input_ids=input_ids)
        >>> output["logits"].shape
        torch.Size([1, 6, 6])
        >>> output["contact_map"].shape
        torch.Size([1, 6, 6])
    """

    pair_score: Tensor
    gaussian_weights: Tensor
    pos_weight: Tensor

    def __init__(self, config: UfoldConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.encoder = UfoldEncoder(
            config.input_channels,
            config.output_channels,
            config.channel_sizes,
            batch_norm_eps=config.batch_norm_eps,
            batch_norm_momentum=config.batch_norm_momentum,
        )
        self.supports_batch_process = True
        self.register_buffer(
            "pair_score",
            torch.tensor(
                [
                    [0.0, 0.0, 0.0, 2.0],
                    [0.0, 0.0, 3.0, 0.0],
                    [0.0, 3.0, 0.0, 0.8],
                    [2.0, 0.0, 0.8, 0.0],
                ]
            ),
            persistent=False,
        )
        self.register_buffer(
            "gaussian_weights",
            torch.exp(-0.5 * torch.arange(30, dtype=torch.float32).pow(2)),
            persistent=False,
        )
        self.register_buffer("pos_weight", torch.tensor([config.pos_weight]), persistent=False)
        self.post_init()

    def _reset_prior_buffers(self) -> None:
        self.pair_score = torch.tensor(
            [
                [0.0, 0.0, 0.0, 2.0],
                [0.0, 0.0, 3.0, 0.0],
                [0.0, 3.0, 0.0, 0.8],
                [2.0, 0.0, 0.8, 0.0],
            ],
            device=self.pair_score.device,
            dtype=self.pair_score.dtype,
        )
        self.gaussian_weights = torch.exp(
            -0.5 * torch.arange(30, device=self.gaussian_weights.device, dtype=self.gaussian_weights.dtype).pow(2)
        )
        self.pos_weight = torch.tensor(
            [self.config.pos_weight],
            device=self.pos_weight.device,
            dtype=self.pos_weight.dtype,
        )

    def postprocess(self, outputs, input_ids=None, **kwargs):
        postprocessed_contact_map = outputs.get("postprocessed_contact_map")
        if postprocessed_contact_map is not None:
            return postprocessed_contact_map
        return outputs["contact_map"]

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        use_postprocessing: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> UfoldModelOutput:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is not None and isinstance(inputs_embeds, NestedTensor):
            raise TypeError("UfoldModel does not support NestedTensor inputs_embeds")

        inputs_embeds = self._prepare_inputs_embeds(input_ids, attention_mask, inputs_embeds)
        lengths = _get_lengths(inputs_embeds, attention_mask)
        inputs_embeds = self._pad_inputs_embeds(inputs_embeds, lengths)

        features = ufold_features(inputs_embeds, self.pair_score, self.gaussian_weights)
        logits_padded = self.encoder(features)
        max_length = int(lengths.max().item()) if lengths.numel() > 0 else 0
        logits = _crop_batch(logits_padded, lengths, max_length)

        should_postprocess = self.config.use_postprocessing if use_postprocessing is None else use_postprocessing
        postprocessed_contact_map = None
        if should_postprocess:
            with torch.no_grad():
                postprocessed_contact_map_padded = self._postprocess(logits_padded, inputs_embeds)
            postprocessed_contact_map = _crop_batch(postprocessed_contact_map_padded, lengths, max_length)
            contact_map = postprocessed_contact_map
        else:
            contact_map = torch.sigmoid(logits)

        loss = None
        if labels is not None:
            labels = labels.to(device=logits.device, dtype=logits.dtype)
            labels = labels[:, :max_length, :max_length]
            valid_mask = _pair_mask(lengths, max_length, logits.device)
            loss = F.binary_cross_entropy_with_logits(
                logits[valid_mask], labels[valid_mask], pos_weight=self.pos_weight
            )

        return UfoldModelOutput(
            loss=loss,
            logits=logits,
            contact_map=contact_map,
            postprocessed_contact_map=postprocessed_contact_map,
        )

    def _prepare_inputs_embeds(
        self,
        input_ids: Tensor | None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
    ) -> Tensor:
        if inputs_embeds is not None:
            if inputs_embeds.size(-1) < 4:
                raise ValueError(f"inputs_embeds last dimension ({inputs_embeds.size(-1)}) must be at least 4.")
            one_hot = inputs_embeds[..., :4].to(dtype=self.pair_score.dtype)
        else:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            one_hot = F.one_hot(input_ids, num_classes=self.config.vocab_size)[..., :4].to(dtype=self.pair_score.dtype)

        if attention_mask is not None:
            one_hot = one_hot * attention_mask.unsqueeze(-1).to(one_hot.dtype)
        return one_hot

    def _pad_inputs_embeds(self, inputs_embeds: Tensor, lengths: Tensor | None = None) -> Tensor:
        if lengths is None:
            max_length = inputs_embeds.size(1)
        else:
            max_length = int(lengths.max().item()) if lengths.numel() > 0 else 0
        padded_length = _get_padded_length(
            max_length,
            min_length=self.config.min_size,
            multiple=self.config.size_multiple,
        )
        return _fit_length(inputs_embeds, padded_length)

    def _postprocess(self, logits: Tensor, one_hot: Tensor) -> Tensor:
        mask = _constraint_matrix(one_hot, allow_noncanonical=self.config.allow_noncanonical).to(dtype=logits.dtype)
        u = torch.sigmoid(2 * (logits - self.config.postprocess_s)) * logits
        a_hat = torch.sigmoid(u) * torch.sigmoid(2 * (u - self.config.postprocess_s)).detach()
        lmbd = F.relu(_contact_a(a_hat, mask).sum(dim=-1) - 1).detach()

        lr_min = self.config.postprocess_lr_min
        lr_max = self.config.postprocess_lr_max
        for _ in range(self.config.postprocess_iterations):
            violation = torch.sigmoid(2 * (_contact_a(a_hat, mask).sum(dim=-1) - 1))
            grad_a = (lmbd * violation).unsqueeze(-1).expand_as(logits) - u / 2
            grad = a_hat * mask * (grad_a + grad_a.transpose(-1, -2))
            a_hat = a_hat - lr_min * grad
            lr_min *= 0.99

            if self.config.postprocess_with_l1:
                a_hat = F.relu(torch.abs(a_hat) - self.config.postprocess_rho * lr_min)

            lmbd_grad = F.relu(_contact_a(a_hat, mask).sum(dim=-1) - 1)
            lmbd = lmbd + lr_max * lmbd_grad
            lr_max *= 0.99

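        # Binarize with the configured probability threshold (config.threshold, 0.5 by default).
        return (_contact_a(a_hat, mask) > self.config.threshold).to(dtype=logits.dtype)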
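`_constraint_matrix` and `_contact_a` are called in `_postprocess` but not shown on this page. The stdlib-only sketch below is a hypothetical illustration of what they are assumed to compute: a mask that allows (i, j) only for A-U, C-G and G-U pairs in either order (the canonical/wobble rule from the training procedure; the `allow_noncanonical` branch and any minimum loop-length rule are not modeled), and the transform used by `contact_a` in the original UFold code, which squares the scores, symmetrizes them and applies the mask. The Lagrange-multiplier step then pushes every row sum of that matrix to at most 1, i.e. each base pairs at most once.

```python
CANONICAL_OR_WOBBLE = {("A", "U"), ("U", "A"), ("C", "G"), ("G", "C"), ("G", "U"), ("U", "G")}


def constraint_mask(sequence: str) -> list[list[float]]:
    """1.0 where sequence[i] and sequence[j] can form a canonical or wobble pair, else 0.0."""
    return [[1.0 if (x, y) in CANONICAL_OR_WOBBLE else 0.0 for y in sequence] for x in sequence]


def contact_transform(a_hat: list[list[float]], mask: list[list[float]]) -> list[list[float]]:
    """Square, symmetrize and mask a score matrix, as in the original UFold contact_a."""
    n = len(a_hat)
    return [
        [(a_hat[i][j] ** 2 + a_hat[j][i] ** 2) / 2 * mask[i][j] for j in range(n)]
        for i in range(n)
    ]


seq = "GGAUCC"
mask = constraint_mask(seq)
scores = [[0.5] * len(seq) for _ in seq]
contacts = contact_transform(scores, mask)
# Row sums are what the multiplier update in _postprocess penalizes when they exceed 1.
row_sums = [sum(row) for row in contacts]
print(row_sums)
```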

UfoldModelOutput dataclass

Bases: ModelOutput

Output type for UFold.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss | `torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided | Binary cross-entropy loss (with positive-class weighting) for base-pair prediction. | None |
| logits | `torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)` | Raw pre-sigmoid prediction scores. These are NOT probabilities; apply torch.sigmoid to obtain per-pair probabilities. | None |
| contact_map | `torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`, *optional* | Base-pair contact map. When use_postprocessing=False (the default), this is the probability matrix torch.sigmoid(logits); when use_postprocessing=True, it is the binary 0/1 map produced by the constrained post-processing loop. | None |
| postprocessed_contact_map | `torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`, *optional* | Binary contact map produced by the constrained UFold post-processing loop. None unless use_postprocessing=True, in which case it is identical to contact_map. | None |
Source code in multimolecule/models/ufold/modeling_ufold.py
Python
@dataclass
class UfoldModelOutput(ModelOutput):
    """
    Output type for UFold.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Binary cross-entropy loss (with positive-class weighting) for base-pair prediction.
        logits (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`):
            Raw pre-sigmoid prediction scores. These are NOT probabilities; apply `torch.sigmoid` to obtain
            per-pair probabilities.
        contact_map (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`, *optional*):
            Base-pair contact map. When `use_postprocessing=False` (the default), this is the probability matrix
            `torch.sigmoid(logits)`; when `use_postprocessing=True`, it is the binary 0/1 map produced by the
            constrained post-processing loop.
        postprocessed_contact_map (`torch.FloatTensor` of shape `(batch_size, seq_len, seq_len)`, *optional*):
            Binary contact map produced by the constrained UFold post-processing loop. `None` unless
            `use_postprocessing=True`, in which case it is identical to `contact_map`.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contact_map: torch.FloatTensor | None = None
    postprocessed_contact_map: torch.FloatTensor | None = None