BPfold

Pre-trained model for RNA secondary structure prediction using base pair motif energy.

Disclaimer

This is an UNOFFICIAL implementation of "Deep generalizable prediction of RNA secondary structure via base pair motif energy" by Heqin Zhu, Fenghe Tang, Quan Quan, Ke Chen, Peng Xiong, and S. Kevin Zhou.

The OFFICIAL repository of BPfold is at heqin-zhu/BPfold.

Tip

The MultiMolecule implementation preserves the released BPfold architecture, base-pair motif energy feature construction, and canonical/non-canonical post-processing semantics.

The team releasing BPfold did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

BPfold predicts RNA base-pair contact maps from a single RNA sequence. It augments a transformer encoder with two L x L base-pair motif energy maps computed from three-neighbor base-pair motifs. MultiMolecule exposes BPfold as a single checkpoint and stores the motif-energy lookup tables inside it.

The model uses:

  • token order: follows the MultiMolecule tokenizer.
  • unknown bases: tokenized as N and treated as U during BPfold feature construction, matching the upstream fallback; padding follows attention_mask.
  • self-attention: dynamic position bias with adjacency bias from motif-energy maps.
  • pairwise convolutions: three residual 2D convolution layers over the adjacency maps before the transformer blocks.
  • post-processing: constrained refinement for canonical pairs, plus the optional BPfold non-canonical pass and mixed canonical/non-canonical outputs.
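
The unknown-base fallback can be illustrated with a small sketch. It assumes the token ids shown in the RnaTokenizer examples on this page (`<null>` = 5, A/C/G/U = 6 to 9, N = 10) and mirrors the `_base_indices` logic from the source listing in pure Python. `base_indices` here is a hypothetical standalone helper, not the library method.

```python
NULL_TOKEN_ID = 5  # assumption: id of <null> in the standard RNA alphabet
NUM_BASES = 4      # A, C, G, U
FALLBACK = 3       # index of U, used for anything that is not A/C/G/U


def base_indices(token_ids):
    """Map token ids to base indices 0..3, sending every other token to U."""
    offset = NULL_TOKEN_ID + 1  # first nucleotide id follows <null>
    result = []
    for token_id in token_ids:
        index = token_id - offset
        result.append(index if 0 <= index < NUM_BASES else FALLBACK)
    return result


print(base_indices([6, 7, 8, 9, 10]))  # A, C, G, U, N -> [0, 1, 2, 3, 3]
```

In the model itself, padded and special-token positions are additionally masked out, so the fallback only matters for real but non-canonical bases such as N.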

Model Specification

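| Num Layers | Hidden Size | Num Heads | Max Num Tokens | Num Parameters (M) | FLOPs (G) | MACs (G) |
| ---------- | ----------- | --------- | -------------- | ------------------ | --------- | -------- |
| 12         | 256         | 8         | 600            | 47.77              | 87.78     | 42.74    |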

FLOPs and MACs are computed with multimolecule.utils for one 600 nt sequence.
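
As a quick consistency check on the specification, the head count follows from the BpfoldConfig defaults listed later on this page (hidden_size 256, attention_head_size 32), and the reported FLOPs are roughly twice the MACs. The snippet below is plain arithmetic on the published numbers and does not call multimolecule.utils.

```python
hidden_size = 256
attention_head_size = 32

# The config requires hidden_size to be divisible by attention_head_size.
num_heads, remainder = divmod(hidden_size, attention_head_size)
assert remainder == 0

flops_g, macs_g = 87.78, 42.74  # values from the specification above
ratio = flops_g / macs_g        # one multiply-accumulate counts as about two FLOPs

print(num_heads, round(ratio, 2))
```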

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

RNA Secondary Structure Pipeline

Python
import multimolecule
from transformers import pipeline

predictor = pipeline("rna-secondary-structure", model="multimolecule/bpfold")
output = predictor("GGUAAAACAGCCUGU")

PyTorch Inference

Python
from multimolecule import BpfoldModel, RnaTokenizer

tokenizer = RnaTokenizer.from_pretrained("multimolecule/bpfold")
model = BpfoldModel.from_pretrained("multimolecule/bpfold")
# Tokenize one RNA sequence into a batch of size 1.
inputs = tokenizer("GGUAAAACAGCCUGU", return_tensors="pt")

output = model(**inputs)
contact_map = output.contact_map  # (1, L, L) base-pair probability matrix
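
A contact map is often easier to inspect as a dot-bracket string. The sketch below is one simple way to do that, assuming the map has already been converted to nested Python lists (for example with `contact_map[0].tolist()`). The helper names, the greedy selection, and the `min_loop=3` value are illustrative assumptions. This is not BPfold's constrained post-processing. The 0.5 threshold matches the `threshold` default in BpfoldConfig.

```python
def crosses(pair, accepted):
    """Return True if `pair` would form a pseudoknot with an accepted pair."""
    i, j = pair
    return any(a < i < b < j or i < a < j < b for a, b in accepted)


def contact_map_to_pairs(contact_map, threshold=0.5, min_loop=3):
    """Greedily pick nested pairs (i < j) from the upper triangle."""
    n = len(contact_map)
    candidates = [
        (contact_map[i][j], i, j)
        for i in range(n)
        for j in range(i + min_loop + 1, n)
        if contact_map[i][j] > threshold
    ]
    candidates.sort(reverse=True)  # most confident pairs first
    used, pairs = set(), []
    for _, i, j in candidates:
        if i in used or j in used or crosses((i, j), pairs):
            continue
        used.update((i, j))
        pairs.append((i, j))
    return sorted(pairs)


def pairs_to_dot_bracket(pairs, length):
    chars = ["."] * length
    for i, j in pairs:
        chars[i], chars[j] = "(", ")"
    return "".join(chars)


# Toy 8 x 8 map with two stacked pairs and one pair that is too close.
toy = [[0.0] * 8 for _ in range(8)]
for i, j, p in [(0, 7, 0.9), (1, 6, 0.8), (2, 4, 0.95)]:
    toy[i][j] = toy[j][i] = p

pairs = contact_map_to_pairs(toy)
print(pairs_to_dot_bracket(pairs, 8))  # ((....))
```

Because crossing pairs are skipped, the output is always a well-formed bracket string, at the cost of dropping any pseudoknots the model predicts.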

Training Details

BPfold was trained for RNA secondary structure prediction with base-pair motif energy priors.

Training Data

  • RNAStrAlign: 37,149 RNAs from eight RNA families were filtered to remove redundant sequences and invalid secondary structures, yielding 29,647 unique RNAs. Sequences longer than 600 nt were removed for training, leaving 19,313 training RNAs.
  • bpRNA-1m: 102,318 RNAs from 2,588 families were deduplicated with CD-HIT at 80% sequence identity and split into TR0/TS0 with 12,114/1,305 RNAs.
  • evaluation data: ArchiveII contains 3,966 RNAs; Rfam12.3-14.10 contains 10,791 RNAs from 1,992 families; bpRNA-new contains 5,401 RNAs; PDB contains 116 high-resolution RNAs split into TS1/TS2/TS3.

Training Procedure

  • objective: binary cross entropy over base-pair contact maps.
  • optimizer: Adam.
  • learning rate: 5e-4.
  • training epochs: 150.
  • batch size: 48.
  • positive-class weight: 300.
  • batching: length-matching mini-batches to reduce padding.
  • sequence features: token embeddings converted to the MultiMolecule tokenizer order.
  • structural priors: two L x L energy maps from three-neighbor base-pair motifs.
  • post-processing: constrained refinement for canonical pairs, minimum loop length, non-overlapping pairs, and isolated-pair removal.
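
To make the effect of the positive-class weight concrete, the following sketch computes the weighted binary cross-entropy for a single logit in pure Python. It follows the standard `pos_weight` definition used by `torch.nn.BCEWithLogitsLoss`, which the BpfoldModel source on this page instantiates. The function names are illustrative.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def weighted_bce_with_logits(logit, label, pos_weight=300.0):
    """Loss for one contact-map entry; label is 1.0 (paired) or 0.0 (unpaired)."""
    log_p = -softplus(-logit)     # log(sigmoid(logit))
    log_not_p = -softplus(logit)  # log(1 - sigmoid(logit))
    return -(pos_weight * label * log_p + (1.0 - label) * log_not_p)


# At logit 0 the model is undecided: a missed pair costs 300 times more
# than a false alarm, which offsets the sparsity of true base pairs.
print(weighted_bce_with_logits(0.0, 1.0))
print(weighted_bce_with_logits(0.0, 0.0))
```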

Citation

BibTeX
@article{zhu2025bpfold,
  title   = {Deep generalizable prediction of {RNA} secondary structure via base pair motif energy},
  author  = {Zhu, Heqin and Tang, Fenghe and Quan, Quan and Chen, Ke and Xiong, Peng and Zhou, S. Kevin},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {5856},
  year    = {2025},
  doi     = {10.1038/s41467-025-60048-1},
  url     = {https://doi.org/10.1038/s41467-025-60048-1}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the BPfold paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.bpfold

RnaTokenizer

Bases: Tokenizer

Tokenizer for RNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): alphabet to use for tokenization.
    • If it is None, the standard RNA alphabet will be used.
    • If it is a string, it should correspond to the name of a predefined alphabet. The options include
      • standard
      • extended
      • streamline
      • nucleobase
    • If it is an alphabet or a list of characters, that specific alphabet will be used.
  • nmers (int, default: 1): Size of kmer to tokenize.
  • codon (bool, default: False): Whether to tokenize into codons.
  • replace_T_with_U (bool, default: True): Whether to replace T with U.
  • do_upper_case (bool, default: True): Whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import RnaTokenizer
>>> tokenizer = RnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHVIX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = RnaTokenizer(replace_T_with_U=False)
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = RnaTokenizer(nmers=3)
>>> tokenizer('uagcuuauc')["input_ids"]
[1, 83, 17, 64, 49, 96, 84, 22, 2]
>>> tokenizer = RnaTokenizer(codon=True)
>>> tokenizer('uagcuuauc')["input_ids"]
[1, 83, 49, 22, 2]
>>> tokenizer('uagcuuauca')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/rna/tokenization_rna.py
Python
class RnaTokenizer(Tokenizer):
    """
    Tokenizer for RNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If it is `None`, the standard RNA alphabet will be used.
            - If it is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `extended`
                + `streamline`
                + `nucleobase`
            - If it is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_T_with_U: Whether to replace T with U.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import RnaTokenizer
        >>> tokenizer = RnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHVIX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = RnaTokenizer(replace_T_with_U=False)
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = RnaTokenizer(nmers=3)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 17, 64, 49, 96, 84, 22, 2]
        >>> tokenizer = RnaTokenizer(codon=True)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 49, 22, 2]
        >>> tokenizer('uagcuuauca')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_T_with_U: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_T_with_U=replace_T_with_U,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_T_with_U = replace_T_with_U
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_T_with_U:
            text = text.replace("T", "U")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
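
The three tokenization modes are easier to compare when the string tokens are visible rather than their ids. The standalone function below re-implements the `_tokenize` logic shown above without any MultiMolecule dependency. It is a sketch for illustration, with `tokenize` as a hypothetical name, and it always upper-cases and replaces T with U (the tokenizer defaults).

```python
def tokenize(text, nmers=1, codon=False):
    """Split an RNA string into single bases, overlapping k-mers, or codons."""
    text = text.upper().replace("T", "U")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(
                f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
            )
        # Codons are non-overlapping windows of three bases.
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # k-mers overlap: the window slides by one base at a time.
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(tokenize("acgt"))                   # ['A', 'C', 'G', 'U']
print(tokenize("uagcuuauc", nmers=3))     # 7 overlapping 3-mers
print(tokenize("uagcuuauc", codon=True))  # ['UAG', 'CUU', 'AUC']
```

This matches the id counts in the examples above: a 9 nt input yields seven 3-mer tokens but only three codon tokens.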

BpfoldConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a BpfoldModel. It is used to instantiate a BPfold model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the BPfold heqin-zhu/BPfold architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default: 11): Token vocabulary size of the BPfold model.
  • hidden_size (int, default: 256): Dimensionality of nucleotide token embeddings and transformer hidden states.
  • num_hidden_layers (int, default: 12): Number of base-pair attention transformer blocks.
  • attention_head_size (int, default: 32): Hidden size per attention head.
  • intermediate_size (int, default: 768): Dimensionality of the feed-forward layer inside each transformer block.
  • hidden_dropout (float, default: 0.1): Dropout probability in transformer blocks.
  • positional_embedding (str, default: 'dyn'): Positional bias type used by self-attention. The original checkpoint uses "dyn".
  • num_pairwise_convolutions (int, default: 3): Number of convolutional layers applied to the pairwise energy map.
  • pairwise_kernel_size (int, default: 3): Kernel size for pairwise energy convolutions.
  • use_squeeze_excitation (bool, default: True): Whether to use squeeze-and-excitation blocks in pairwise convolutions.
  • use_base_pair_energy (bool, default: True): Whether to use base-pair motif energy maps.
  • use_base_pair_probability (bool, default: False): Whether to use an externally provided base-pair probability map.
  • separate_outer_inner_energy (bool, default: True): Whether motif energy is represented as separate outer and inner energy maps.
  • motif_radius (int, default: 3): Number of neighboring bases in base-pair motifs. The published BPfold model uses three.
  • max_length (int, default: 600): Training-time maximum sequence length used by the original checkpoints.
  • threshold (float, default: 0.5): Probability threshold for predicting base pairs during post-processing.
  • use_postprocessing (bool, default: False): Whether to run the constrained BPfold post-processing loop in forward.
  • postprocess_iterations (int, default: 100): Number of constrained post-processing iterations.
  • postprocess_lr_min (float, default: 0.01): Learning rate for the minimization step in post-processing.
  • postprocess_lr_max (float, default: 0.1): Learning rate for the Lagrangian multiplier maximization step in post-processing.
  • postprocess_rho (float, default: 1.6): L1 sparsity coefficient used by canonical post-processing.
  • postprocess_nc_rho (float, default: 0.5): L1 sparsity coefficient used by non-canonical post-processing.
  • postprocess_with_l1 (bool, default: True): Whether to apply L1 shrinkage in post-processing.
  • postprocess_s (float, default: 1.5): Logit cutoff used by canonical post-processing.
  • postprocess_nc_s (float, default: 0.5): Logit cutoff used by non-canonical post-processing.
  • pos_weight (float, default: 300.0): Positive-class weight used by the original weighted binary cross-entropy training loss.
  • num_members (int, default: 6): Number of internal checkpoint members in the released BPfold predictor.
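
The `num_members` setting controls how many member networks are ensembled. In the BpfoldModel.forward source on this page, the member logits are stacked and averaged, and the sigmoid is applied afterwards when post-processing is off. The sketch below reproduces that order of operations for a single contact-map entry in pure Python. The logit values are made up for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def ensemble_probability(member_logits):
    """Average the members' logits first, then squash to a probability."""
    mean_logit = sum(member_logits) / len(member_logits)
    return sigmoid(mean_logit)


# Hypothetical logits from six members for one (i, j) entry.
logits = [2.0, -1.0, 0.5, 1.5, 0.0, 3.0]

prob_from_mean_logit = ensemble_probability(logits)
mean_of_probs = sum(sigmoid(x) for x in logits) / len(logits)

# The two aggregation orders generally disagree; the source uses the first.
print(round(prob_from_mean_logit, 4), round(mean_of_probs, 4))
```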

Examples:

Python Console Session
>>> from multimolecule import BpfoldConfig, BpfoldModel
>>> configuration = BpfoldConfig()
>>> model = BpfoldModel(configuration)
>>> configuration = model.config
Source code in multimolecule/models/bpfold/configuration_bpfold.py
Python
class BpfoldConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BpfoldModel`][multimolecule.models.BpfoldModel]. It is used to instantiate a BPfold model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the BPfold [heqin-zhu/BPfold](https://github.com/heqin-zhu/BPfold) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Token vocabulary size of the BPfold model.
        hidden_size:
            Dimensionality of nucleotide token embeddings and transformer hidden states.
        num_hidden_layers:
            Number of base-pair attention transformer blocks.
        attention_head_size:
            Hidden size per attention head.
        intermediate_size:
            Dimensionality of the feed-forward layer inside each transformer block.
        hidden_dropout:
            Dropout probability in transformer blocks.
        positional_embedding:
            Positional bias type used by self-attention. The original checkpoint uses `"dyn"`.
        num_pairwise_convolutions:
            Number of convolutional layers applied to the pairwise energy map.
        pairwise_kernel_size:
            Kernel size for pairwise energy convolutions.
        use_squeeze_excitation:
            Whether to use squeeze-and-excitation blocks in pairwise convolutions.
        use_base_pair_energy:
            Whether to use base-pair motif energy maps.
        use_base_pair_probability:
            Whether to use an externally provided base-pair probability map.
        separate_outer_inner_energy:
            Whether motif energy is represented as separate outer and inner energy maps.
        motif_radius:
            Number of neighboring bases in base-pair motifs. The published BPfold model uses three.
        max_length:
            Training-time maximum sequence length used by the original checkpoints.
        threshold:
            Probability threshold for predicting base pairs during post-processing.
        use_postprocessing:
            Whether to run the constrained BPfold post-processing loop in `forward`.
        postprocess_iterations:
            Number of constrained post-processing iterations.
        postprocess_lr_min:
            Learning rate for the minimization step in post-processing.
        postprocess_lr_max:
            Learning rate for the Lagrangian multiplier maximization step in post-processing.
        postprocess_rho:
            L1 sparsity coefficient used by canonical post-processing.
        postprocess_nc_rho:
            L1 sparsity coefficient used by non-canonical post-processing.
        postprocess_with_l1:
            Whether to apply L1 shrinkage in post-processing.
        postprocess_s:
            Logit cutoff used by canonical post-processing.
        postprocess_nc_s:
            Logit cutoff used by non-canonical post-processing.
        pos_weight:
            Positive-class weight used by the original weighted binary cross-entropy training loss.
        num_members:
            Number of internal checkpoint members in the released BPfold predictor.

    Examples:
        >>> from multimolecule import BpfoldConfig, BpfoldModel
        >>> configuration = BpfoldConfig()
        >>> model = BpfoldModel(configuration)
        >>> configuration = model.config
    """

    model_type = "bpfold"

    def __init__(
        self,
        vocab_size: int = 11,
        hidden_size: int = 256,
        num_hidden_layers: int = 12,
        attention_head_size: int = 32,
        intermediate_size: int = 768,
        hidden_dropout: float = 0.1,
        positional_embedding: str = "dyn",
        num_pairwise_convolutions: int = 3,
        pairwise_kernel_size: int = 3,
        use_squeeze_excitation: bool = True,
        use_base_pair_energy: bool = True,
        use_base_pair_probability: bool = False,
        separate_outer_inner_energy: bool = True,
        motif_radius: int = 3,
        max_length: int = 600,
        threshold: float = 0.5,
        use_postprocessing: bool = False,
        postprocess_iterations: int = 100,
        postprocess_lr_min: float = 0.01,
        postprocess_lr_max: float = 0.1,
        postprocess_rho: float = 1.6,
        postprocess_nc_rho: float = 0.5,
        postprocess_with_l1: bool = True,
        postprocess_s: float = 1.5,
        postprocess_nc_s: float = 0.5,
        pos_weight: float = 300.0,
        num_members: int = 6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if hidden_size % attention_head_size != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by attention_head_size ({attention_head_size})."
            )
        if positional_embedding not in {"dyn", "alibi"}:
            raise ValueError(f"positional_embedding must be 'dyn' or 'alibi', but got {positional_embedding!r}.")
        if motif_radius != 3:
            raise ValueError("BPfold currently supports the published 3-neighbor motif energy table only.")
        if num_members <= 0:
            raise ValueError(f"num_members must be positive, but got {num_members}.")

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.attention_head_size = attention_head_size
        self.intermediate_size = intermediate_size
        self.hidden_dropout = hidden_dropout
        self.positional_embedding = positional_embedding
        self.num_pairwise_convolutions = num_pairwise_convolutions
        self.pairwise_kernel_size = pairwise_kernel_size
        self.use_squeeze_excitation = use_squeeze_excitation
        self.use_base_pair_energy = use_base_pair_energy
        self.use_base_pair_probability = use_base_pair_probability
        self.separate_outer_inner_energy = separate_outer_inner_energy
        self.motif_radius = motif_radius
        self.max_length = max_length
        self.threshold = threshold
        self.use_postprocessing = use_postprocessing
        self.postprocess_iterations = postprocess_iterations
        self.postprocess_lr_min = postprocess_lr_min
        self.postprocess_lr_max = postprocess_lr_max
        self.postprocess_rho = postprocess_rho
        self.postprocess_nc_rho = postprocess_nc_rho
        self.postprocess_with_l1 = postprocess_with_l1
        self.postprocess_s = postprocess_s
        self.postprocess_nc_s = postprocess_nc_s
        self.pos_weight = pos_weight
        self.num_members = num_members

BpfoldModel

Bases: BpfoldPreTrainedModel

Source code in multimolecule/models/bpfold/modeling_bpfold.py
Python
class BpfoldModel(BpfoldPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import BpfoldConfig, BpfoldModel
        >>> config = BpfoldConfig(
        ...     hidden_size=8, attention_head_size=4, intermediate_size=16, num_hidden_layers=1, num_members=1
        ... )
        >>> model = BpfoldModel(config)
        >>> input_ids = torch.tensor([[1, 6, 7, 8, 9, 2]])
        >>> output = model(input_ids=input_ids)
        >>> output["logits"].shape
        torch.Size([1, 4, 4])
        >>> output["contact_map"].shape
        torch.Size([1, 4, 4])
    """

    # Buffers registered in __init__; annotated here for type checkers.
    outer_energy: Tensor
    inner_chain_energy: Tensor
    inner_hairpin_energy: Tensor
    _pair_index: Tensor

    def __init__(self, config: BpfoldConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.members = nn.ModuleList([BpfoldModule(config) for _ in range(config.num_members)])
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([config.pos_weight]))
        self.supports_batch_process = True

        outer_shape, inner_chain_shape, inner_hairpin_shape = _energy_table_shapes(
            num_bases=4,
            motif_radius=config.motif_radius,
        )
        self.register_buffer("outer_energy", torch.zeros(outer_shape))
        self.register_buffer("inner_chain_energy", torch.zeros(inner_chain_shape))
        self.register_buffer("inner_hairpin_energy", torch.zeros(inner_hairpin_shape))
        self.register_buffer("_pair_index", _pair_index_matrix(), persistent=False)

        self.post_init()

    def postprocess(self, outputs, input_ids=None, **kwargs):
        postprocessed_contact_map = outputs.get("postprocessed_contact_map")
        if postprocessed_contact_map is not None:
            return postprocessed_contact_map
        return outputs["contact_map"]

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        base_pair_energy: Tensor | None = None,
        base_pair_probability: Tensor | None = None,
        use_postprocessing: bool | None = None,
        return_noncanonical: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BpfoldModelOutput:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is not None and isinstance(inputs_embeds, NestedTensor):
            raise TypeError("BpfoldModel does not support NestedTensor inputs_embeds")

        if inputs_embeds is not None:
            inputs_embeds, attention_mask = self._prepare_inputs_embeds(inputs_embeds, attention_mask)
            network_length = inputs_embeds.size(1)
            base_one_hot = None
            valid_mask = _pair_mask(attention_mask.sum(dim=-1).long(), network_length, attention_mask.device)
            if self.config.use_base_pair_energy:
                if base_pair_energy is None:
                    raise ValueError(
                        "base_pair_energy must be provided when using inputs_embeds with use_base_pair_energy=True."
                    )
                base_pair_energy = _fit_pairwise_feature(base_pair_energy, network_length)
            else:
                base_pair_energy = None
            if base_pair_probability is not None:
                base_pair_probability = _fit_pairwise_feature(base_pair_probability, network_length)
            token_ids = None
        else:
            token_ids, attention_mask, base_indices, base_lengths, base_one_hot = self._prepare_input_ids(
                input_ids, attention_mask
            )
            network_length = token_ids.size(1)
            base_start = 1 if self.config.bos_token_id is not None else 0
            if self.config.use_base_pair_energy:
                if base_pair_energy is None:
                    base_pair_energy = self._base_pair_energy(base_indices, base_lengths, network_length, base_start)
                else:
                    base_pair_energy = _pad_pairwise_feature(base_pair_energy, base_lengths, network_length, base_start)
            else:
                base_pair_energy = None
            if base_pair_probability is not None:
                base_pair_probability = _pad_pairwise_feature(
                    base_pair_probability, base_lengths, network_length, base_start
                )

        member_logits = [
            member(
                input_ids=token_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                base_pair_energy=base_pair_energy,
                base_pair_probability=base_pair_probability,
            )
            for member in self.members
        ]
        logits_with_tokens = torch.stack(member_logits, dim=0).mean(dim=0)

        if inputs_embeds is not None:
            logits = logits_with_tokens
        else:
            logits, valid_mask, _ = self._remove_special_tokens_2d(
                logits_with_tokens.unsqueeze(-1), attention_mask, token_ids
            )
            logits = logits.squeeze(-1)
            valid_mask = valid_mask.bool()

        should_postprocess = self.config.use_postprocessing if use_postprocessing is None else use_postprocessing
        postprocessed_contact_map = None
        noncanonical_contact_map = None
        if should_postprocess:
            if base_one_hot is None:
                raise ValueError("input_ids are required for BPfold post-processing when using inputs_embeds.")
            with torch.no_grad():
                postprocessed_contact_map = self._postprocess(logits, base_one_hot, is_noncanonical=False)
                if return_noncanonical:
                    noncanonical_contact_map = self._postprocess(logits, base_one_hot, is_noncanonical=True)
            contact_map = postprocessed_contact_map
        else:
            contact_map = torch.sigmoid(logits)

        loss = None
        if labels is not None:
            labels = labels.to(device=logits.device, dtype=logits.dtype)
            max_length = logits.size(-1)
            labels = labels[:, :max_length, :max_length]
            loss = self.criterion(logits[valid_mask], labels[valid_mask])

        return BpfoldModelOutput(
            loss=loss,
            logits=logits,
            contact_map=contact_map,
            postprocessed_contact_map=postprocessed_contact_map,
            noncanonical_contact_map=noncanonical_contact_map,
        )

    def _prepare_inputs_embeds(
        self,
        inputs_embeds: Tensor,
        attention_mask: Tensor | None,
    ) -> tuple[Tensor, Tensor]:
        if inputs_embeds.size(-1) != self.config.hidden_size:
            raise ValueError(
                f"inputs_embeds last dimension ({inputs_embeds.size(-1)}) must equal hidden_size "
                f"({self.config.hidden_size})."
            )
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.size()[:2], dtype=torch.bool, device=inputs_embeds.device)
        else:
            attention_mask = attention_mask.to(device=inputs_embeds.device, dtype=torch.bool)
        network_length = int(attention_mask.long().sum(dim=-1).max().item()) if attention_mask.numel() > 0 else 0
        return inputs_embeds[:, :network_length], attention_mask[:, :network_length]

    def _prepare_input_ids(
        self,
        input_ids: Tensor | None,
        attention_mask: Tensor | None,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        if input_ids is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        else:
            attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.bool)
        network_length = int(attention_mask.long().sum(dim=-1).max().item()) if attention_mask.numel() > 0 else 0
        token_ids = input_ids[:, :network_length].clamp(min=0, max=self.config.vocab_size - 1)
        attention_mask = attention_mask[:, :network_length]
        _, base_mask, base_token_ids = self._remove_special_tokens(
            token_ids.new_ones((*token_ids.shape, 1), dtype=torch.float32),
            attention_mask,
            token_ids,
        )
        base_mask = base_mask.bool()
        base_lengths = base_mask.sum(dim=-1).long()
        base_indices = self._base_indices(base_token_ids)
        base_one_hot = F.one_hot(base_indices, num_classes=4).to(dtype=self.outer_energy.dtype)
        base_one_hot = base_one_hot * base_mask.unsqueeze(-1).to(base_one_hot.dtype)
        return token_ids, attention_mask, base_indices, base_lengths, base_one_hot

    def _remove_special_tokens(
        self,
        output: Tensor,
        attention_mask: Tensor | None = None,
        input_ids: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        if self.config.bos_token_id is not None:
            output = output[..., 1:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        if self.config.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.config.eos_token_id).to(output.device)
                input_ids = input_ids.masked_fill(~eos_mask, self.config.pad_token_id or 0)[..., :-1]
            elif attention_mask is not None:
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            else:
                raise ValueError("Unable to remove EOS tokens because input_ids and attention_mask are both None")
            output = (output * eos_mask.unsqueeze(-1))[..., :-1, :]
            if attention_mask is not None:
                attention_mask = (attention_mask * eos_mask)[..., :-1]
        if attention_mask is not None:
            output = output * attention_mask.unsqueeze(-1)
        return output, attention_mask, input_ids

    def _remove_special_tokens_2d(
        self,
        output: Tensor,
        attention_mask: Tensor | None = None,
        input_ids: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        if self.config.bos_token_id is not None:
            output = output[..., 1:, 1:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        if self.config.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.config.eos_token_id).to(output.device)
                input_ids = input_ids.masked_fill(~eos_mask, self.config.pad_token_id or 0)[..., :-1]
            elif attention_mask is not None:
                # Without input_ids, take the last attended position of each sequence as EOS.
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            else:
                raise ValueError("Unable to remove EOS tokens because input_ids and attention_mask are both None")
            if attention_mask is not None:
                attention_mask = (attention_mask * eos_mask)[..., :-1]
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            output = (output * eos_mask.unsqueeze(-1))[..., :-1, :-1, :]
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
            output = output * attention_mask.unsqueeze(-1)
        return output, attention_mask, input_ids

    def _base_indices(self, input_ids: Tensor) -> Tensor:
        if self.config.null_token_id is None:
            raise ValueError("BpfoldModel requires null_token_id to infer nucleotide token ids from input_ids.")
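        # Nucleotide ids are assumed to follow the null token directly, so this offset maps them to 0..3.
        base_offset = self.config.null_token_id + 1
        base_indices = input_ids - base_offset
        # Any other token (such as N) falls back to index 3, i.e. it is treated as U.
        return torch.where(
            (base_indices >= 0) & (base_indices < 4),
            base_indices,
            torch.full_like(base_indices, 3),
        )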

    def _base_pair_energy(
        self,
        base_indices: Tensor,
        base_lengths: Tensor,
        target_length: int,
        base_start: int,
    ) -> Tensor:
        batch_size = base_indices.size(0)
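        # One channel per energy map when outer and inner motif energies are kept separate,
        # otherwise a single channel.
        num_channels = 2 if self.config.separate_outer_inner_energy else 1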
        energy = self.outer_energy.new_zeros((batch_size, num_channels, target_length, target_length))
        pair_index = _pair_index_matrix().to(device=base_indices.device)

        for batch_index in range(batch_size):
            length = int(base_lengths[batch_index].item())
            if length <= 0:
                continue
            seq = base_indices[batch_index, :length]
            seq_energy = _build_energy_map_from_tokens(
                seq,
                pair_index,
                self.outer_energy,
                self.inner_chain_energy,
                self.inner_hairpin_energy,
                num_bases=4,
                motif_radius=self.config.motif_radius,
                separate_outer_inner=self.config.separate_outer_inner_energy,
            )
            # Write the length x length map at the base offset; positions outside it stay zero.
            base_end = base_start + length
            energy[batch_index, :, base_start:base_end, base_start:base_end] = seq_energy
        return energy

    def _postprocess(self, logits: Tensor, base_one_hot: Tensor, is_noncanonical: bool = False) -> Tensor:
        if is_noncanonical:
            rho = self.config.postprocess_nc_rho
            threshold_logit = self.config.postprocess_nc_s
        else:
            rho = self.config.postprocess_rho
            threshold_logit = self.config.postprocess_s

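        mask = _constraint_matrix(base_one_hot, is_noncanonical=is_noncanonical).to(dtype=logits.dtype)
        # sigmoid(2 * x) acts as a smooth step, so logits below threshold_logit are damped toward zero.
        u = torch.sigmoid(2 * (logits - threshold_logit)) * logits
        a_hat = torch.sigmoid(u) * torch.sigmoid(2 * (u - threshold_logit)).detach()
        # Initial multiplier: how far each row's summed contact score exceeds one.
        lmbd = F.relu(_contact_a(a_hat, mask).sum(dim=-1) - 1).detach()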

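        # lr_min is the step size for a_hat and lr_max the step size for lmbd; both decay by 0.99 per iteration.
        lr_min = self.config.postprocess_lr_min
        lr_max = self.config.postprocess_lr_max
        for _ in range(self.config.postprocess_iterations):
            # Gradient step on a_hat, penalizing rows whose summed contact score exceeds one.
            violation = torch.sigmoid(2 * (_contact_a(a_hat, mask).sum(dim=-1) - 1))
            grad_a = (lmbd * violation).unsqueeze(-1).expand_as(logits) - u / 2
            grad = a_hat * mask * (grad_a + grad_a.transpose(-1, -2))
            a_hat = a_hat - lr_min * grad
            lr_min *= 0.99

            if self.config.postprocess_with_l1:
                # Soft-thresholding step that pushes small entries to zero.
                a_hat = F.relu(torch.abs(a_hat) - rho * lr_min)

            # Ascent step on the multiplier using the remaining row-sum violation.
            lmbd_grad = F.relu(_contact_a(a_hat, mask).sum(dim=-1) - 1)
            lmbd = lmbd + lr_max * lmbd_grad
            lr_max *= 0.99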

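        # Binarize the refined contact scores at the configured threshold.
        contact_map = _contact_a(a_hat, mask)
        contact_map = (contact_map > self.config.threshold).to(dtype=logits.dtype)
        if is_noncanonical:
            # Keep only the pairs selected by the non-canonical pairing matrix.
            contact_map = contact_map * _noncanonical_matrix(base_one_hot).to(dtype=logits.dtype)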
        return contact_map