Skip to content

RNA Secondary Structure

multimolecule.pipelines.rna_secondary_structure

RnaSecondaryStructurePipeline

Bases: Pipeline

RNA secondary structure prediction pipeline using any ModelWithSecondaryStructureHead.

Examples:

Python
1
2
3
4
5
6
7
>>> import multimolecule
>>> from transformers import pipeline

>>> predictor = pipeline("rna-secondary-structure")
>>> output = predictor("UAGCUUAUCAGACUGAUGUUG")
>>> output["secondary_structure"]
'.....................'

Learn more about the basics of using a pipeline in the pipeline tutorial

This secondary structure prediction pipeline can currently be loaded from [pipeline] using the following task identifier: "rna-secondary-structure".

The models that this pipeline can use are models that have been trained with an RNA secondary structure prediction objective, which includes the bi-directional models in the library. See the up-to-date list of available models on huggingface.co/models.

Source code in multimolecule/pipelines/rna_secondary_structure.py
Python
class RnaSecondaryStructurePipeline(Pipeline):
    """
    RNA secondary structure prediction pipeline using any `ModelWithSecondaryStructureHead`.

    Examples:

    ```python
    >>> import multimolecule
    >>> from transformers import pipeline

    >>> predictor = pipeline("rna-secondary-structure")
    >>> output = predictor("UAGCUUAUCAGACUGAUGUUG")
    >>> output["secondary_structure"]
    '.....................'

    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This secondary structure prediction pipeline can currently be loaded from [`pipeline`] using the following task
    identifier: `"rna-secondary-structure"`.

    The models that this pipeline can use are models that have been trained with an RNA secondary structure prediction
    objective, which includes the bi-directional models in the library. See the up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=rna-secondary-structure).
    """

    # Default cut-off applied to the contact map when deciding whether a contact is present.
    threshold: float = 0.5
    # Default for whether results also carry the contact map.
    output_contact_map: bool = False

    def preprocess(
        self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
    ) -> Dict[str, GenericTensor]:
        """
        Tokenize the raw input sequence(s).

        Args:
            inputs: The sequence(s) handed to the tokenizer.
            return_tensors: Tensor framework requested from the tokenizer. Defaults to `self.framework`.
            tokenizer_kwargs: Extra keyword arguments forwarded to the tokenizer call.
            **preprocess_parameters: Accepted for interface compatibility; not used here.

        Return:
            The tokenizer output, passed unchanged to `_forward`.
        """
        if return_tensors is None:
            return_tensors = self.framework
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)

    def _forward(self, model_inputs):
        """
        Run the model and attach the `input_ids` to its outputs so `postprocess` can decode the sequences.

        A `RuntimeWarning` is emitted when more than one sequence is processed at once by a model that does not
        declare `supports_batch_process`.
        """
        model_outputs = self.model(**model_inputs)
        model_outputs["input_ids"] = model_inputs["input_ids"]

        # Warn only for models that do NOT advertise batch support; the previous condition was inverted and
        # warned exactly for the models that declared they could handle batches.
        if len(model_inputs["input_ids"]) > 1 and not getattr(self.model, "supports_batch_process", False):
            warn(
                "The pipeline received a batch of sequences as input.\n"
                "Most RNA Secondary Structure models are designed and trained with a single sequence.\n"
                "The results may be less reliable in batch processing.\n",
                RuntimeWarning,
            )
        return model_outputs

    def _postprocess(self, contact_map: GenericTensor) -> GenericTensor:
        """
        Turn raw contact logits into values in (0, 1) via a sigmoid.

        The diagonal is filled with the most negative finite value of the dtype first, so self-contacts end up at
        (effectively) zero after the sigmoid. NOTE(review): `fill_diagonal_` is in-place, so this presumably also
        writes into the tensor the logits were taken from -- confirm that is acceptable for callers.
        """
        contact_map = contact_map.squeeze(-1)
        contact_map.fill_diagonal_(torch.finfo(contact_map.dtype).min)
        contact_map = contact_map.sigmoid()
        return contact_map

    def postprocess(self, model_outputs, threshold: float | None = None, output_contact_map: bool | None = None):
        """
        Convert model outputs into sequence / dot-bracket results.

        Args:
            model_outputs: Outputs of `_forward`; must contain `input_ids`, plus `logits_ss` or `logits` unless the
                model provides its own `postprocess` method.
            threshold: Cut-off applied to the contact map. Defaults to `self.threshold`.
            output_contact_map: Whether to include the contact map(s). Defaults to `self.output_contact_map`.

        Return:
            A `dict` with keys `sequence` and `secondary_structure`, plus `contact_map` when requested. For a single
            sequence the values are scalars (`str` / `numpy.ndarray`); for several sequences they are lists.

        Raises:
            PipelineException: If the model has no `postprocess` method and no logits are found in the outputs.
        """
        if threshold is None:
            threshold = self.threshold
        if output_contact_map is None:
            output_contact_map = self.output_contact_map

        input_ids = model_outputs["input_ids"]
        if hasattr(self.model, "postprocess"):
            # The model ships its own post-processing; its result is used as the contact map directly.
            outputs = self.model.postprocess(model_outputs)
            postprocessed = True
        else:
            if "logits_ss" in model_outputs:
                outputs = model_outputs["logits_ss"]
            elif "logits" in model_outputs:
                outputs = model_outputs["logits"]
            else:
                raise PipelineException(
                    "rna-secondary-structure", self.model.base_model_prefix, "Unable to find logits in model outputs."
                )
            postprocessed = False

        if len(input_ids) == 1:
            sequence = self.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True).replace(" ", "")
            contact_map = outputs.squeeze(0)
            if not postprocessed:
                contact_map = self._postprocess(contact_map)
            dot_bracket = contact_map_to_dot_bracket(contact_map > threshold, unsafe=True)
            ret = {"sequence": sequence, "secondary_structure": dot_bracket}
            if output_contact_map:
                ret["contact_map"] = contact_map.numpy()
            return ret

        sequences = [i.replace(" ", "") for i in self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
        dot_brackets = []
        contact_maps = []
        for sequence, contact_map in zip(sequences, outputs):
            # Crop to the decoded sequence length before scoring.
            contact_map = contact_map[: len(sequence), : len(sequence)]
            if not postprocessed:
                contact_map = self._postprocess(contact_map)
            dot_brackets.append(contact_map_to_dot_bracket(contact_map > threshold, unsafe=True))
            # Honour the resolved per-call flag rather than the instance default, so a call-time
            # `output_contact_map` override also applies to multi-sequence outputs.
            if output_contact_map:
                contact_maps.append(contact_map.numpy())
        ret = {"sequence": sequences, "secondary_structure": dot_brackets}
        if output_contact_map:
            ret["contact_map"] = contact_maps
        return ret

    def _validate_threshold(self, threshold: float) -> None:
        """Raise `PipelineException` unless `threshold` lies strictly between 0 and 1."""
        if threshold >= 1:
            raise PipelineException(
                "rna-secondary-structure",
                self.model.base_model_prefix,
                f"Threshold must be less than 1, but got {threshold}.",
            )
        if threshold <= 0:
            raise PipelineException(
                "rna-secondary-structure",
                self.model.base_model_prefix,
                f"Threshold must be greater than 0, but got {threshold}.",
            )

    def _validate_output_contact_map(self, output_contact_map) -> None:
        """Raise `PipelineException` unless `output_contact_map` is a `bool`."""
        if not isinstance(output_contact_map, bool):
            raise PipelineException(
                "rna-secondary-structure",
                self.model.base_model_prefix,
                f"output_contact_map must be a boolean, but got {type(output_contact_map)}.",
            )

    def _sanitize_parameters(
        self, threshold: float | None = None, output_contact_map: bool | None = None, tokenizer_kwargs=None
    ):
        """
        Validate call-time parameters and route them to the pipeline stages.

        Return:
            A `(preprocess_params, forward_params, postprocess_params)` tuple; the forward parameters are always
            empty.

        Raises:
            PipelineException: If `threshold` is not strictly between 0 and 1, or `output_contact_map` is not a
                `bool`.
        """
        preprocess_params = {}
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs

        postprocess_params = {}
        if threshold is not None:
            self._validate_threshold(threshold)
            postprocess_params["threshold"] = threshold
        if output_contact_map is not None:
            self._validate_output_contact_map(output_contact_map)
            postprocess_params["output_contact_map"] = output_contact_map
        return preprocess_params, {}, postprocess_params

    def __init__(self, *args, threshold: float | None = None, output_contact_map: bool | None = None, **kwargs):
        """
        Build the pipeline and optionally override the instance-wide defaults.

        Args:
            threshold: Instance default for the contact cut-off; must be strictly between 0 and 1.
            output_contact_map: Instance default for whether results include the contact map.

        Raises:
            NotImplementedError: If the pipeline framework is not PyTorch (`"pt"`).
            PipelineException: If `threshold` or `output_contact_map` is invalid.
        """
        super().__init__(*args, **kwargs)
        if self.framework != "pt":
            raise NotImplementedError("Only PyTorch is supported for RNA secondary structure prediction.")
        if threshold is not None:
            self._validate_threshold(threshold)
            self.threshold = threshold
        if output_contact_map is not None:
            self._validate_output_contact_map(output_contact_map)
            self.output_contact_map = output_contact_map

    def __call__(self, inputs, **kwargs):
        """
        Predict the secondary structure of the RNA sequence(s) given as inputs.

        Args:
            inputs (`str` or `List[str]`):
                One or several RNA sequences.
            threshold (`float`, *optional*):
                The threshold to use for determining if a contact is present or not. If not provided, the default is
                `0.5`. The value must be between 0 and 1 (exclusive).
            output_contact_map (`bool`, *optional*):
                Whether to output the contact map along with the secondary structure. If not provided, the default is
                `False`.

        Return:
            A `dict` or a list of `dict`: A list containing exactly one sequence is unwrapped to its single result;
            otherwise the value returned by the base `Pipeline.__call__` is passed through. Each result has the
            following keys:

            - **sequence** (`str`) -- The input sequence as decoded by the tokenizer, with special tokens and spaces
              removed.
            - **secondary_structure** (`str`) -- The predicted secondary structure in dot-bracket notation.
            - **contact_map** (`numpy.ndarray`) -- The predicted contact map. Only present when
              `output_contact_map` is `True`.
        """
        outputs = super().__call__(inputs, **kwargs)
        # A one-element list is unwrapped so callers get the same shape as for a plain `str` input.
        if isinstance(inputs, list) and len(inputs) == 1:
            return outputs[0]
        return outputs

__call__

Python
__call__(inputs, **kwargs)

Predict the secondary structure of the RNA sequence(s) given as inputs.

Parameters:

Name Type Description Default
inputs
`str` or `List[str]`

One or several RNA sequences.

required
threshold
`float`, *optional*

The threshold to use for determining if a contact is present or not. If not provided, the default is 0.5. The value must be between 0 and 1.

0.5
output_contact_map
`bool`, *optional*

Whether to output the contact map along with the secondary structure. If not provided, the default is False.

False
Return

A dict or a list of dict: Each result is a dictionary with the following keys:

  • sequence (str) – The input sequence as decoded by the tokenizer, with special tokens and spaces removed.
  • secondary_structure (str) – The predicted secondary structure in dot-bracket notation.
  • contact_map (numpy.ndarray) – The predicted contact map. Only present when output_contact_map is True.
Source code in multimolecule/pipelines/rna_secondary_structure.py
Python
def __call__(self, inputs, **kwargs):
    """
    Predict the secondary structure of the RNA sequence(s) given as inputs.

    Args:
        inputs (`str` or `List[str]`):
            One or several RNA sequences.
        threshold (`float`, *optional*):
            The threshold to use for determining if a contact is present or not. If not provided, the default is
            `0.5`. The value must be between 0 and 1 (exclusive).
        output_contact_map (`bool`, *optional*):
            Whether to output the contact map along with the secondary structure. If not provided, the default is
            `False`.

    Return:
        A `dict` or a list of `dict`: A list containing exactly one sequence is unwrapped to its single result;
        otherwise the value returned by the base `Pipeline.__call__` is passed through. Each result has the
        following keys:

        - **sequence** (`str`) -- The input sequence as decoded by the tokenizer, with special tokens and spaces
          removed.
        - **secondary_structure** (`str`) -- The predicted secondary structure in dot-bracket notation.
        - **contact_map** (`numpy.ndarray`) -- The predicted contact map. Only present when
          `output_contact_map` is `True`.
    """
    outputs = super().__call__(inputs, **kwargs)
    # A one-element list is unwrapped so callers get the same shape as for a plain `str` input.
    if isinstance(inputs, list) and len(inputs) == 1:
        return outputs[0]
    return outputs