Skip to content

Splice Site

SpliceSitePipeline

Bases: Pipeline

Splice-site pipeline for RNA splicing models.

The pipeline accepts raw nucleotide sequences and returns biological position-level scores. It supports classical fixed-window splice-site scorers such as MaxEntScan and per-nucleotide models such as OpenSpliceAI, Pangolin, and SpTransformer.

Source code in multimolecule/pipelines/splicing.py
Python
class SpliceSitePipeline(Pipeline):
    """
    Splice-site pipeline for RNA splicing models.

    The pipeline accepts raw nucleotide sequences and returns biological position-level scores. It supports classical
    fixed-window splice-site scorers such as MaxEntScan and per-nucleotide models such as OpenSpliceAI, Pangolin, and
    SpTransformer.

    ``threshold`` and ``output_scores`` passed to the constructor replace the class-level defaults below. The same
    options (plus ``top_k``) may also be passed per call, in which case the per-call value takes precedence.
    """

    # Default `threshold` forwarded to `_splice_site_result`; overrides are validated with `_check_probability`.
    threshold: float = 0.5
    # Default `output_scores` flag forwarded to `_splice_site_result`; overrides must be a real `bool`.
    output_scores: bool = True
    # Default `top_k` forwarded to `_splice_site_result`; `None` is forwarded as-is.
    top_k: int | None = None

    def __init__(self, *args, threshold: float | None = None, output_scores: bool | None = None, **kwargs):
        """
        Build the pipeline and optionally override the default ``threshold`` / ``output_scores``.

        Raises:
            NotImplementedError: If the wrapped model is not a ``torch.nn.Module``.
            PipelineException: If ``output_scores`` is not a ``bool``. An invalid ``threshold`` is rejected by
                ``_check_probability``.
        """
        super().__init__(*args, **kwargs)
        # Validation runs after `super().__init__` because the error messages are built from `self.model`.
        if not isinstance(self.model, torch.nn.Module):
            raise NotImplementedError("Only PyTorch is supported for splice-site prediction.")
        if threshold is not None:
            self.threshold = self._validate_threshold(threshold)
        if output_scores is not None:
            self.output_scores = self._validate_output_scores(output_scores)

    def preprocess(
        self,
        inputs: str,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize ``inputs`` into model-ready tensors.

        ``return_tensors`` defaults to ``"pt"``. ``tokenizer_kwargs`` is normalized by ``_tokenizer_kwargs`` before
        being forwarded to the tokenizer. Extra ``preprocess_parameters`` are accepted and ignored.
        """
        if return_tensors is None:
            return_tensors = "pt"
        tokenizer_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)

    def _forward(self, model_inputs):
        """Run the model and attach ``input_ids`` / ``attention_mask`` so ``postprocess`` can decode the sequences."""
        outputs = self.model(**model_inputs)
        outputs["input_ids"] = model_inputs["input_ids"]
        # The attention mask is optional; store `None` when the tokenizer did not produce one.
        outputs["attention_mask"] = model_inputs.get("attention_mask")
        return outputs

    def postprocess(
        self,
        model_outputs,
        threshold: float | None = None,
        output_scores: bool | None = None,
        top_k: int | None = None,
    ):
        """
        Convert model outputs into one splice-site result per input sequence.

        Any option left as ``None`` falls back to the corresponding instance default.

        Returns:
            A single result when the batch holds exactly one sequence, otherwise a list of results.
        """
        if threshold is None:
            threshold = self.threshold
        if output_scores is None:
            output_scores = self.output_scores
        if top_k is None:
            top_k = self.top_k

        sequences = _decode_sequences(
            self.tokenizer, model_outputs["input_ids"], model_outputs.get("attention_mask")
        )
        config = getattr(self.model, "config", None)
        scores, channels = _processed_scores(model_outputs, model=self.model, config=config, task=self.task)

        results = [
            _splice_site_result(
                sequence,
                sample_scores,
                channels=channels,
                threshold=threshold,
                output_scores=output_scores,
                top_k=top_k,
            )
            for sequence, sample_scores in zip(sequences, _sample_tensors(scores, len(sequences)))
        ]
        # Unwrap single-sequence batches so a single input yields a single result.
        if len(results) == 1:
            return results[0]
        return results

    def _sanitize_parameters(
        self,
        threshold: float | None = None,
        output_scores: bool | None = None,
        top_k: int | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Split per-call keyword arguments into ``(preprocess, forward, postprocess)`` parameter dicts.

        Only options that were explicitly provided (not ``None``) are forwarded, so ``postprocess`` falls back to the
        instance defaults for the rest.
        """
        preprocess_params: dict[str, Any] = {}
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs

        postprocess_params: dict[str, Any] = {}
        if threshold is not None:
            postprocess_params["threshold"] = self._validate_threshold(threshold)
        if output_scores is not None:
            postprocess_params["output_scores"] = self._validate_output_scores(output_scores)
        if top_k is not None:
            postprocess_params["top_k"] = self._validate_top_k(top_k)
        return preprocess_params, {}, postprocess_params

    def _validate_threshold(self, threshold: float) -> float:
        """Check ``threshold`` with ``_check_probability`` and return it unchanged."""
        _check_probability("threshold", threshold, self.task, _model_prefix(self.model))
        return threshold

    def _validate_output_scores(self, output_scores: Any) -> bool:
        """Raise ``PipelineException`` unless ``output_scores`` is a ``bool``; return it unchanged."""
        if not isinstance(output_scores, bool):
            raise PipelineException(
                self.task,
                _model_prefix(self.model),
                f"output_scores must be a boolean, got {type(output_scores)}.",
            )
        return output_scores

    def _validate_top_k(self, top_k: Any) -> int:
        """Raise ``PipelineException`` unless ``top_k`` is a positive integer; return it unchanged."""
        # `bool` subclasses `int`, so reject it explicitly. Non-integers are rejected up front: previously floats such
        # as 2.5 slipped through, and non-numbers raised an unrelated TypeError from the `<=` comparison.
        if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k <= 0:
            raise PipelineException(self.task, _model_prefix(self.model), "top_k must be a positive integer.")
        return top_k