
UFold

Pre-trained model for RNA secondary structure prediction using an image-like sequence representation and a U-Net.

Disclaimer

This is an UNOFFICIAL implementation of UFold: fast and accurate RNA secondary structure prediction with deep learning by Laiyi Fu, Yingxin Cao, Jie Wu, Qinke Peng, Qing Nie, and Xiaohui Xie.

The OFFICIAL repository of UFold is at uci-cbcl/UFold.

Tip

The MultiMolecule implementation is a direct PyTorch port of the original U-Net architecture and feature construction.

The team releasing UFold did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

UFold predicts RNA base-pair contact maps from single RNA sequences. It represents a sequence of length L as a 17-channel L × L image: 16 channels are outer products of one-hot nucleotide indicators, and one channel is a hand-crafted pairing score that rewards canonical (A-U, G-C) and wobble (G-U) pairs together with their Gaussian-weighted stacked neighbors. A U-Net predicts a symmetric contact score matrix, and the original constrained post-processing routine can optionally be enabled to enforce base-pairing constraints.
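The 16 outer-product channels can be sketched in plain Python as follows. This is an illustrative re-implementation, not the MultiMolecule code: it assumes the A, C, G, U channel order and the 4 * a + b channel indexing produced by itertools.product, and one_hot and outer_product_channels are hypothetical helper names.

```python
from itertools import product

# Assumed channel order: A, C, G, U (the first four entries of the UFold vocabulary).
NUCLEOTIDES = "ACGU"


def one_hot(sequence: str) -> list[list[float]]:
    """One-hot encode a sequence; symbols outside ACGU (e.g. N) become all-zero rows."""
    return [[1.0 if base == nt else 0.0 for nt in NUCLEOTIDES] for base in sequence.upper()]


def outer_product_channels(sequence: str) -> list[list[list[float]]]:
    """Return 16 L x L channels; channel 4 * a + b is the outer product of indicators a and b."""
    encoded = one_hot(sequence)
    length = len(encoded)
    channels = []
    for a, b in product(range(4), range(4)):
        # channel[i][j] is 1.0 exactly when position i is base a and position j is base b
        channels.append([[encoded[i][a] * encoded[j][b] for j in range(length)] for i in range(length)])
    return channels


channels = outer_product_channels("GGAUCC")
gc = channels[4 * NUCLEOTIDES.index("G") + NUCLEOTIDES.index("C")]  # channel 9
print(len(channels), gc[0][5], gc[2][5])  # 16 1.0 0.0
```

For a sequence made only of A, C, G and U, exactly one of the 16 channels is 1 at every (i, j); an N contributes zeros to all of them.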

Model Specification

Num Parameters (M) FLOPs (G) MACs (G)
8.64 188.29 93.81

FLOPs and MACs are computed with multimolecule.utils for one 600 nt sequence.
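The 600 nt figure is the unpadded length. Before the U-Net, UfoldModel pads the input to a size derived from min_size and size_multiple (through _get_padded_length, whose body is not shown on this page). A minimal sketch, assuming that rule is "at least 80, rounded up to a multiple of 16"; padded_size is a hypothetical helper:

```python
import math


def padded_size(length: int, min_size: int = 80, size_multiple: int = 16) -> int:
    """Hypothetical padding rule: at least min_size, rounded up to a multiple of size_multiple."""
    return math.ceil(max(length, min_size) / size_multiple) * size_multiple


size = padded_size(600)
print(size)  # 608, i.e. 38 * 16
print(size // 2**4)  # 38: spatial size after four 2x downsampling stages
# FLOPs are close to 2 x MACs, since each multiply-accumulate is usually counted as two operations.
print(188.29 / 93.81)
```

Under this assumption, a 600 nt input becomes a 608 × 608 image, and every sequence up to 80 nt shares the same 80 × 80 input size.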

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

RNA Secondary Structure Pipeline

Python
import multimolecule  # registers MultiMolecule models and pipelines with transformers
from transformers import pipeline

predictor = pipeline("rna-secondary-structure", model="multimolecule/ufold")
output = predictor("GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCUCA")

PyTorch Inference

Python
from multimolecule import RnaTokenizer, UfoldModel

tokenizer = RnaTokenizer.from_pretrained("multimolecule/ufold")
model = UfoldModel.from_pretrained("multimolecule/ufold")

sequence = "GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCCCUGAUAAGGGUGAGGUCGCUGAUUCGAAUUCAGCAUAGCUCA"
inputs = tokenizer(sequence, return_tensors="pt")
output = model(**inputs)

contact_map = output.contact_map

To run the original constrained post-processing loop:

Python
output = model(**inputs, use_postprocessing=True)
contact_map = output.postprocessed_contact_map
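To turn a contact map into base pairs, one can threshold it and read off the upper triangle. The sketch below is hypothetical (it is not part of MultiMolecule) and works on plain nested lists; with the model above you could pass output.contact_map[0].tolist(), assuming a batch of one sequence.

```python
def contact_map_to_pairs(contact_map: list[list[float]], threshold: float = 0.5) -> list[tuple[int, int]]:
    """Return 0-based (i, j) pairs with i < j whose score exceeds the threshold."""
    length = len(contact_map)
    return [(i, j) for i in range(length) for j in range(i + 1, length) if contact_map[i][j] > threshold]


def pairs_to_dot_bracket(pairs: list[tuple[int, int]], length: int) -> str:
    """Render pairs as dot-bracket; assumes nested pairs and at most one partner per position."""
    structure = ["."] * length
    for i, j in pairs:
        structure[i], structure[j] = "(", ")"
    return "".join(structure)


# Toy symmetric 6 x 6 contact map with two confident pairs and one weak one.
toy = [[0.0] * 6 for _ in range(6)]
for (i, j), score in {(0, 5): 0.9, (1, 4): 0.8, (2, 3): 0.3}.items():
    toy[i][j] = toy[j][i] = score

pairs = contact_map_to_pairs(toy)
print(pairs)  # [(0, 5), (1, 4)]
print(pairs_to_dot_bracket(pairs, 6))  # ((..))
```

Without post-processing, a position can exceed the threshold with more than one partner; this sketch does not resolve such conflicts.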

Training Details

UFold was trained to predict base-pair contact maps supervised by annotated RNA secondary structures; base-pairing rules are built into the input features and the post-processing step.

Training Data

  • RNAStrAlign: 30,451 unique RNAs from eight RNA families; the paper reports a random split with 24,895 training RNAs and 2,854 test RNAs after redundancy filtering.
  • bpRNA-1m: 102,318 RNAs from 2,588 families; CD-HIT was used to remove redundant sequences before splitting the data into TR0 and TS0.
  • augmented data: synthetic training examples were generated from bpRNA-new sequences by random mutation and structure prediction.
  • PDB training data: high-resolution RNA structures from bpRNA and the PDB were used for fine-tuning/evaluation experiments; test sets TS1, TS2, and TS3 were filtered at 80% sequence identity.
  • evaluation data: ArchiveII, TS0, bpRNA-new, and PDB test data were used for benchmark evaluation.

Training Procedure

  • input representation: 16 outer-product channels following the MultiMolecule tokenizer order plus one hand-crafted pairing-score channel.
  • objective: weighted binary cross entropy over base-pair contact maps.
  • optimizer: Adam.
  • training epochs: 100.
  • batch size: 1.
  • positive-class weight: 300.
  • post-processing: constrained optimization with canonical/wobble pairing rules, sparsity shrinkage, and a 0.5 threshold.
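The weighted objective has the form computed by PyTorch's binary_cross_entropy_with_logits with pos_weight, which the model source below applies only to pairs inside the unpadded sequence. A large positive-class weight compensates for sparsity: at most L of the L² contact-map entries of a length-L sequence are positive, because each position has at most one partner and the map is symmetric. A dependency-free sketch over a flat list of pair logits; softplus and weighted_bce_with_logits are hypothetical helpers, and the mean reduction mirrors PyTorch's default:

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def weighted_bce_with_logits(logits: list[float], labels: list[float], pos_weight: float = 300.0) -> float:
    """Mean of -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    total = 0.0
    for x, y in zip(logits, labels):
        # -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x)
        total += pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
    return total / len(logits)


# One paired and one unpaired entry, both predicted at logit 0 (probability 0.5):
# the missed pair costs 300 times more than the false positive.
print(weighted_bce_with_logits([0.0, 0.0], [1.0, 0.0]))
```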

Citation

BibTeX
@article{fu2022ufold,
  author = {Fu, Laiyi and Cao, Yingxin and Wu, Jie and Peng, Qinke and Nie, Qing and Xie, Xiaohui},
  title = {UFold: fast and accurate RNA secondary structure prediction with deep learning},
  journal = {Nucleic Acids Research},
  volume = {50},
  number = {3},
  pages = {e14},
  year = {2022},
  doi = {10.1093/nar/gkab1074}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the UFold paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.ufold

RnaTokenizer

Bases: Tokenizer

Tokenizer for RNA sequences.

Parameters:

Name Type Description Default

alphabet

Alphabet | str | List[str] | None

alphabet to use for tokenization.

  • If None, the standard RNA alphabet will be used.
  • If a string, it should correspond to the name of a predefined alphabet. The options include
    • standard
    • extended
    • streamline
    • nucleobase
  • If an alphabet or a list of characters, that specific alphabet will be used.
None

nmers

int

Size of the k-mers to tokenize into.

1

codon

bool

Whether to tokenize into codons.

False

replace_T_with_U

bool

Whether to replace T with U.

True

do_upper_case

bool

Whether to convert input to uppercase.

True

Examples:

Python Console Session
>>> from multimolecule import RnaTokenizer
>>> tokenizer = RnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHVIX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = RnaTokenizer(replace_T_with_U=False)
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = RnaTokenizer(nmers=3)
>>> tokenizer('uagcuuauc')["input_ids"]
[1, 83, 17, 64, 49, 96, 84, 22, 2]
>>> tokenizer = RnaTokenizer(codon=True)
>>> tokenizer('uagcuuauc')["input_ids"]
[1, 83, 49, 22, 2]
>>> tokenizer('uagcuuauca')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
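To see which strings the ids above stand for, the splitting step can be reproduced without the vocabulary. The sketch below copies the rules of _tokenize from the source that follows; split_tokens is a hypothetical name, and the special tokens (<cls> = 1 and <eos> = 2 in the examples) are added by the tokenizer afterwards, not by this step.

```python
def split_tokens(text: str, nmers: int = 1, codon: bool = False) -> list[str]:
    """Mirror of the splitting rules in RnaTokenizer._tokenize (upper-casing and T -> U included)."""
    text = text.upper().replace("T", "U")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length must be a multiple of 3 for codon tokenization, but got {len(text)}")
        return [text[i : i + 3] for i in range(0, len(text), 3)]  # non-overlapping triplets
    if nmers > 1:
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]  # overlapping k-mers
    return list(text)


print(split_tokens("uagcuuauc", nmers=3))  # ['UAG', 'AGC', 'GCU', 'CUU', 'UUA', 'UAU', 'AUC']
print(split_tokens("uagcuuauc", codon=True))  # ['UAG', 'CUU', 'AUC']
```

This explains the examples: the codon ids 83, 49 and 22 are the ids of the 1st, 4th and 7th overlapping 3-mers, because codons are the non-overlapping subset of those k-mers.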
Source code in multimolecule/tokenisers/rna/tokenization_rna.py
Python
class RnaTokenizer(Tokenizer):
    """
    Tokenizer for RNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard RNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `extended`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of the k-mers to tokenize into.
        codon: Whether to tokenize into codons.
        replace_T_with_U: Whether to replace T with U.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import RnaTokenizer
        >>> tokenizer = RnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHVIX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = RnaTokenizer(replace_T_with_U=False)
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = RnaTokenizer(nmers=3)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 17, 64, 49, 96, 84, 22, 2]
        >>> tokenizer = RnaTokenizer(codon=True)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 49, 22, 2]
        >>> tokenizer('uagcuuauca')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_T_with_U: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_T_with_U=replace_T_with_U,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_T_with_U = replace_T_with_U
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_T_with_U:
            text = text.replace("T", "U")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)

UfoldConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a UfoldModel. It is used to instantiate a UFold model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the UFold uci-cbcl/UFold architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

Name Type Description Default

vocab_size

int

Token vocabulary size of the UFold model. Defaults to 5 for the A/C/G/U/N tokenizer vocabulary.

5

input_channels

int

Number of image-like input channels. The original UFold model uses 16 outer-product base-pair channels plus one hand-crafted pairing-score channel.

17

output_channels

int

Number of U-Net output channels. The original UFold model predicts one contact-score matrix.

1

channel_sizes

list[int] | None

U-Net channel sizes for the five resolution levels. If None, the original [32, 64, 128, 256, 512] is used.

None

min_size

int

Minimum padded image size used before the U-Net. The original short-sequence dataset pads to 80.

80

size_multiple

int

Spatial size multiple required by the four downsampling stages.

16

batch_norm_eps

float

Epsilon used by the BatchNorm2d layers.

1e-05

batch_norm_momentum

float

Momentum used by the BatchNorm2d layers.

0.1

threshold

float

Probability threshold for predicting base pairs during post-processing.

0.5

use_postprocessing

bool

Whether to run the UFold constrained post-processing loop in forward.

False

postprocess_iterations

int

Number of constrained post-processing iterations.

100

postprocess_lr_min

float

Learning rate for the minimization step in UFold post-processing.

0.01

postprocess_lr_max

float

Learning rate for the Lagrangian multiplier maximization step in UFold post-processing.

0.1

postprocess_rho

float

L1 sparsity coefficient used by UFold post-processing.

1.6

postprocess_with_l1

bool

Whether to apply L1 shrinkage in UFold post-processing.

True

postprocess_s

float

Logit cutoff used by UFold post-processing. Defaults to log(9), the original value.

2.1972245773362196

allow_noncanonical

bool

Whether post-processing should allow non-canonical base pairs.

False

pos_weight

float

Positive-class weight used by the original weighted binary cross-entropy training loss.

300.0
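The default postprocess_s is the logit of probability 0.9, and _postprocess in the UfoldModel source below uses it as a soft gate on the raw scores before the constrained optimization starts. A small numeric check in plain Python; the sigmoid helper is defined here only for illustration:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


s = math.log(9)  # default postprocess_s, about 2.1972
# A logit of s corresponds to a pairing probability of 9 / (1 + 9) = 0.9.
print(sigmoid(s))

# Soft gate applied to the raw logits at the start of UfoldModel._postprocess:
#   u = sigmoid(2 * (logits - s)) * logits
for logit in (0.0, 1.0, s, 6.0):
    gated = sigmoid(2 * (logit - s)) * logit
    print(f"logit={logit:.3f} -> gated={gated:.3f}")
```

Logits well below s are pushed towards 0, logits well above it pass almost unchanged, and a logit exactly at s is halved.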

Examples:

Python Console Session
>>> from multimolecule import UfoldConfig, UfoldModel
>>> configuration = UfoldConfig()
>>> model = UfoldModel(configuration)
>>> configuration = model.config
Source code in multimolecule/models/ufold/configuration_ufold.py
Python
class UfoldConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`UfoldModel`][multimolecule.models.UfoldModel]. It is used to instantiate a UFold model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the UFold [uci-cbcl/UFold](https://github.com/uci-cbcl/UFold) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Token vocabulary size of the UFold model. Defaults to 5 for the `A/C/G/U/N` tokenizer vocabulary.
        input_channels:
            Number of image-like input channels. The original UFold model uses 16 outer-product base-pair channels
            plus one hand-crafted pairing-score channel.
        output_channels:
            Number of U-Net output channels. The original UFold model predicts one contact-score matrix.
        channel_sizes:
            U-Net channel sizes for the five resolution levels. If `None`, the original `[32, 64, 128, 256, 512]` is
            used.
        min_size:
            Minimum padded image size used before the U-Net. The original short-sequence dataset pads to 80.
        size_multiple:
            Spatial size multiple required by the four downsampling stages.
        batch_norm_eps:
            Epsilon used by the BatchNorm2d layers.
        batch_norm_momentum:
            Momentum used by the BatchNorm2d layers.
        threshold:
            Probability threshold for predicting base pairs during post-processing.
        use_postprocessing:
            Whether to run the UFold constrained post-processing loop in `forward`.
        postprocess_iterations:
            Number of constrained post-processing iterations.
        postprocess_lr_min:
            Learning rate for the minimization step in UFold post-processing.
        postprocess_lr_max:
            Learning rate for the Lagrangian multiplier maximization step in UFold post-processing.
        postprocess_rho:
            L1 sparsity coefficient used by UFold post-processing.
        postprocess_with_l1:
            Whether to apply L1 shrinkage in UFold post-processing.
        postprocess_s:
            Logit cutoff used by UFold post-processing. Defaults to `log(9)`, the original value.
        allow_noncanonical:
            Whether post-processing should allow non-canonical base pairs.
        pos_weight:
            Positive-class weight used by the original weighted binary cross-entropy training loss.

    Examples:
        >>> from multimolecule import UfoldConfig, UfoldModel
        >>> configuration = UfoldConfig()
        >>> model = UfoldModel(configuration)
        >>> configuration = model.config
    """

    model_type = "ufold"

    def __init__(
        self,
        vocab_size: int = 5,
        input_channels: int = 17,
        output_channels: int = 1,
        channel_sizes: list[int] | None = None,
        min_size: int = 80,
        size_multiple: int = 16,
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        threshold: float = 0.5,
        use_postprocessing: bool = False,
        postprocess_iterations: int = 100,
        postprocess_lr_min: float = 0.01,
        postprocess_lr_max: float = 0.1,
        postprocess_rho: float = 1.6,
        postprocess_with_l1: bool = True,
        postprocess_s: float = 2.1972245773362196,
        allow_noncanonical: bool = False,
        pos_weight: float = 300.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if input_channels != 17:
            raise ValueError(f"UFold expects 17 input channels, but got {input_channels}.")
        if output_channels != 1:
            raise ValueError(f"UFold expects one output channel, but got {output_channels}.")
        if channel_sizes is None:
            channel_sizes = [32, 64, 128, 256, 512]
        if len(channel_sizes) != 5:
            raise ValueError(f"UFold expects five channel sizes, but got {len(channel_sizes)}.")
        if min_size <= 0:
            raise ValueError(f"min_size must be positive, but got {min_size}.")
        if size_multiple <= 0:
            raise ValueError(f"size_multiple must be positive, but got {size_multiple}.")

        self.vocab_size = vocab_size
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.channel_sizes = channel_sizes
        self.min_size = min_size
        self.size_multiple = size_multiple
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.threshold = threshold
        self.use_postprocessing = use_postprocessing
        self.postprocess_iterations = postprocess_iterations
        self.postprocess_lr_min = postprocess_lr_min
        self.postprocess_lr_max = postprocess_lr_max
        self.postprocess_rho = postprocess_rho
        self.postprocess_with_l1 = postprocess_with_l1
        self.postprocess_s = postprocess_s
        self.allow_noncanonical = allow_noncanonical
        self.pos_weight = pos_weight
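The five channel sizes imply five resolution levels separated by four downsampling steps, which is why padded sizes must be multiples of 16 (2^4). A sketch of the resulting encoder feature-map shapes, assuming each step halves height and width as in a standard U-Net (2 × 2 max-pooling in the original UFold code); unet_level_shapes is a hypothetical helper:

```python
def unet_level_shapes(
    padded_size: int, channel_sizes: tuple[int, ...] = (32, 64, 128, 256, 512)
) -> list[tuple[int, int, int]]:
    """(channels, height, width) per encoder level, assuming a 2x downsampling between levels."""
    shapes = []
    size = padded_size
    for level, channels in enumerate(channel_sizes):
        if level > 0:
            if size % 2 != 0:
                raise ValueError(f"size {size} cannot be halved evenly at level {level}")
            size //= 2
        shapes.append((channels, size, size))
    return shapes


for shape in unet_level_shapes(80):
    print(shape)  # (32, 80, 80) down to (512, 5, 5)
```

A size such as 88 fails at the fourth halving (88, 44, 22, 11), so the decoder's skip connections could no longer line up; rounding up to a multiple of 16 avoids this.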

UfoldModel

Bases: UfoldPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import UfoldConfig, UfoldModel
>>> config = UfoldConfig(postprocess_iterations=1)
>>> model = UfoldModel(config)
>>> input_ids = torch.tensor([[0, 3, 2, 1, 0, 3]])
>>> output = model(input_ids=input_ids)
>>> output["logits"].shape
torch.Size([1, 6, 6])
>>> output["contact_map"].shape
torch.Size([1, 6, 6])
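The 17th input channel combines the pair_score and gaussian_weights buffers registered in the model source below. The ufold_features helper itself is not shown on this page, so the following is a sketch modeled on the preprocessing of the original UFold repository: a pair (i, j) is scored by its own canonical or wobble score plus Gaussian-decayed scores of the stacked pairs around it, stopping at the first unpairable position. Treat the loop structure as an assumption, not a description of the MultiMolecule code; pairing_score is a hypothetical name.

```python
import math

# Pair scores from the pair_score buffer in UfoldModel: A-U = 2, G-C = 3, G-U = 0.8 (symmetric).
PAIR_SCORE = {
    ("A", "U"): 2.0,
    ("U", "A"): 2.0,
    ("G", "C"): 3.0,
    ("C", "G"): 3.0,
    ("G", "U"): 0.8,
    ("U", "G"): 0.8,
}
# Weights exp(-0.5 * k**2) for offsets k = 0..29, as in the gaussian_weights buffer.
GAUSSIAN = [math.exp(-0.5 * k * k) for k in range(30)]


def pairing_score(sequence: str) -> list[list[float]]:
    """Score each (i, j) by Gaussian-weighted scores of the contiguous stack of pairs through it."""
    seq = sequence.upper()
    n = len(seq)
    mat = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            score = 0.0
            # Walk outwards over (i - k, j + k), starting with (i, j) itself at k = 0.
            for k in range(30):
                if i - k < 0 or j + k >= n:
                    break
                pair = PAIR_SCORE.get((seq[i - k], seq[j + k]), 0.0)
                if pair == 0.0:
                    break
                score += pair * GAUSSIAN[k]
            # Walk inwards over (i + k, j - k), but only if (i, j) itself can pair.
            if score > 0.0:
                for k in range(1, 30):
                    if i + k >= n or j - k < 0:
                        break
                    pair = PAIR_SCORE.get((seq[i + k], seq[j - k]), 0.0)
                    if pair == 0.0:
                        break
                    score += pair * GAUSSIAN[k]
            mat[i][j] = score
    return mat


scores = pairing_score("GGGAAACCC")
print(round(scores[0][8], 3), round(scores[1][7], 3), scores[0][1])
```

For GGGAAACCC, the middle G-C pair of the three-pair stem, (1, 7), scores higher than the outer pair (0, 8) because it has stacked neighbors on both sides.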
Source code in multimolecule/models/ufold/modeling_ufold.py
Python
class UfoldModel(UfoldPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import UfoldConfig, UfoldModel
        >>> config = UfoldConfig(postprocess_iterations=1)
        >>> model = UfoldModel(config)
        >>> input_ids = torch.tensor([[0, 3, 2, 1, 0, 3]])
        >>> output = model(input_ids=input_ids)
        >>> output["logits"].shape
        torch.Size([1, 6, 6])
        >>> output["contact_map"].shape
        torch.Size([1, 6, 6])
    """

    def __init__(self, config: UfoldConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.encoder = UfoldEncoder(
            config.input_channels,
            config.output_channels,
            config.channel_sizes,
            batch_norm_eps=config.batch_norm_eps,
            batch_norm_momentum=config.batch_norm_momentum,
        )
        self.supports_batch_process = True
        self.register_buffer(
            "pair_score",
            torch.tensor(
                [
                    [0.0, 0.0, 0.0, 2.0],
                    [0.0, 0.0, 3.0, 0.0],
                    [0.0, 3.0, 0.0, 0.8],
                    [2.0, 0.0, 0.8, 0.0],
                ]
            ),
            persistent=False,
        )
        self.register_buffer(
            "gaussian_weights",
            torch.exp(-0.5 * torch.arange(30, dtype=torch.float32).pow(2)),
            persistent=False,
        )
        self.register_buffer("pos_weight", torch.tensor([config.pos_weight]), persistent=False)
        self.post_init()

    def postprocess(self, outputs, input_ids=None, **kwargs):
        postprocessed_contact_map = outputs.get("postprocessed_contact_map")
        if postprocessed_contact_map is not None:
            return postprocessed_contact_map
        return outputs["contact_map"]

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        use_postprocessing: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> UfoldModelOutput:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is not None and isinstance(inputs_embeds, NestedTensor):
            raise TypeError("UfoldModel does not support NestedTensor inputs_embeds")

        inputs_embeds = self._prepare_inputs_embeds(input_ids, attention_mask, inputs_embeds)
        lengths = _get_lengths(inputs_embeds, attention_mask)
        inputs_embeds = self._pad_inputs_embeds(inputs_embeds, lengths)

        features = ufold_features(inputs_embeds, self.pair_score, self.gaussian_weights)
        logits_padded = self.encoder(features)
        max_length = int(lengths.max().item()) if lengths.numel() > 0 else 0
        logits = _crop_batch(logits_padded, lengths, max_length)

        should_postprocess = self.config.use_postprocessing if use_postprocessing is None else use_postprocessing
        postprocessed_contact_map = None
        if should_postprocess:
            with torch.no_grad():
                postprocessed_contact_map_padded = self._postprocess(logits_padded, inputs_embeds)
            postprocessed_contact_map = _crop_batch(postprocessed_contact_map_padded, lengths, max_length)
            contact_map = postprocessed_contact_map
        else:
            contact_map = torch.sigmoid(logits)

        loss = None
        if labels is not None:
            labels = labels.to(device=logits.device, dtype=logits.dtype)
            labels = labels[:, :max_length, :max_length]
            valid_mask = _pair_mask(lengths, max_length, logits.device)
            loss = F.binary_cross_entropy_with_logits(
                logits[valid_mask], labels[valid_mask], pos_weight=self.pos_weight
            )

        return UfoldModelOutput(
            loss=loss,
            logits=logits,
            contact_map=contact_map,
            postprocessed_contact_map=postprocessed_contact_map,
        )

    def _prepare_inputs_embeds(
        self,
        input_ids: Tensor | None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
    ) -> Tensor:
        if inputs_embeds is not None:
            if inputs_embeds.size(-1) < 4:
                raise ValueError(f"inputs_embeds last dimension ({inputs_embeds.size(-1)}) must be at least 4.")
            one_hot = inputs_embeds[..., :4].to(dtype=self.pair_score.dtype)
        else:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            one_hot = F.one_hot(input_ids, num_classes=self.config.vocab_size)[..., :4].to(dtype=self.pair_score.dtype)

        if attention_mask is not None:
            one_hot = one_hot * attention_mask.unsqueeze(-1).to(one_hot.dtype)
        return one_hot

    def _pad_inputs_embeds(self, inputs_embeds: Tensor, lengths: Tensor | None = None) -> Tensor:
        if lengths is None:
            max_length = inputs_embeds.size(1)
        else:
            max_length = int(lengths.max().item()) if lengths.numel() > 0 else 0
        padded_length = _get_padded_length(
            max_length,
            min_length=self.config.min_size,
            multiple=self.config.size_multiple,
        )
        return _fit_length(inputs_embeds, padded_length)

    def _postprocess(self, logits: Tensor, one_hot: Tensor) -> Tensor:
        mask = _constraint_matrix(one_hot, allow_noncanonical=self.config.allow_noncanonical).to(dtype=logits.dtype)
        u = torch.sigmoid(2 * (logits - self.config.postprocess_s)) * logits
        a_hat = torch.sigmoid(u) * torch.sigmoid(2 * (u - self.config.postprocess_s)).detach()
        lmbd = F.relu(_contact_a(a_hat, mask).sum(dim=-1) - 1).detach()

        lr_min = self.config.postprocess_lr_min
        lr_max = self.config.postprocess_lr_max
        for _ in range(self.config.postprocess_iterations):
            violation = torch.sigmoid(2 * (_contact_a(a_hat, mask).sum(dim=-1) - 1))
            grad_a = (lmbd * violation).unsqueeze(-1).expand_as(logits) - u / 2
            grad = a_hat * mask * (grad_a + grad_a.transpose(-1, -2))
            a_hat = a_hat - lr_min * grad
            lr_min *= 0.99

            if self.config.postprocess_with_l1:
                a_hat = F.relu(torch.abs(a_hat) - self.config.postprocess_rho * lr_min)

            lmbd_grad = F.relu(_contact_a(a_hat, mask).sum(dim=-1) - 1)
            lmbd = lmbd + lr_max * lmbd_grad
            lr_max *= 0.99

        return _contact_a(a_hat, mask)
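The helpers _constraint_matrix and _contact_a used by _postprocess are not shown on this page. The sketch below follows the corresponding functions in the original UFold repository, so treat it as an assumption about their behavior: the mask keeps only canonical and wobble pairs (the allow_noncanonical=False case), and the contact matrix is the symmetrized square of a_hat. The Lagrange multiplier lmbd in the loop above grows wherever a row sum exceeds 1, that is, wherever a position would take more than one partner. All names here are hypothetical.

```python
CANONICAL = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}


def constraint_mask(sequence: str) -> list[list[float]]:
    """1.0 where positions i and j can form an A-U, G-C or G-U pair, else 0.0."""
    seq = sequence.upper()
    return [[1.0 if (a, b) in CANONICAL else 0.0 for b in seq] for a in seq]


def contact_a(a_hat: list[list[float]], mask: list[list[float]]) -> list[list[float]]:
    """(a_hat**2 + transpose(a_hat**2)) / 2, restricted to allowed pairs by the mask."""
    n = len(a_hat)
    return [[(a_hat[i][j] ** 2 + a_hat[j][i] ** 2) / 2 * mask[i][j] for j in range(n)] for i in range(n)]


def row_violations(contact: list[list[float]]) -> list[float]:
    """relu(sum_j A_ij - 1): how far each position exceeds 'at most one partner'."""
    return [max(sum(row) - 1.0, 0.0) for row in contact]


# In GGCC every G may pair with both Cs, so a uniform a_hat of 0.9 violates the constraint.
mask = constraint_mask("GGCC")
contact = contact_a([[0.9] * 4 for _ in range(4)], mask)
print(row_violations(contact))
```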

UfoldModelOutput dataclass

Bases: ModelOutput

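Output type for UFold. loss is the weighted binary cross-entropy and is computed only when labels are given; logits are the raw contact scores cropped to the longest sequence in the batch; contact_map is the sigmoid of logits, or the post-processed map when post-processing runs; postprocessed_contact_map is set only when post-processing runs.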

Source code in multimolecule/models/ufold/modeling_ufold.py
Python
@dataclass
class UfoldModelOutput(ModelOutput):
    """
    Output type for UFold.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    contact_map: torch.FloatTensor | None = None
    postprocessed_contact_map: torch.FloatTensor | None = None