Skip to content

Splice Variant Effect

SpliceVariantEffectPipeline

Bases: Pipeline

Splice variant-effect pipeline.

The pipeline scores a reference sequence and an alternative sequence and returns the alternative-minus-reference delta. Models with a native paired-input variant-effect head, such as MMSplice and MTSplice, are called once with both sequences. All other models are scored on the reference and alternative sequences separately, and the delta is the alternative score minus the reference score.

Source code in multimolecule/pipelines/splicing.py
Python
class SpliceVariantEffectPipeline(Pipeline):
    """
    Splice variant-effect pipeline.

    The pipeline scores a reference sequence and an alternative sequence and returns the alternative-minus-reference
    delta. Models with a native paired-input variant-effect head, such as MMSplice and MTSplice, are called once with
    both sequences. Other models are scored on reference and alternative separately, and the delta is computed as the
    alternative scores minus the reference scores.

    Only PyTorch models are supported.
    """

    # Default ``top_k`` used by ``postprocess`` when the caller does not supply one.
    top_k: int | None = 20

    def preprocess(
        self,
        inputs: str | Mapping[str, str],
        alternative: str | None = None,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize the reference and alternative sequences.

        Args:
            inputs: The reference sequence, or a mapping that ``_resolve_variant_inputs`` turns into a
                ``(reference, alternative)`` pair (NOTE(review): accepted keys are defined by that helper — confirm).
            alternative: The alternative sequence, used together with a plain-string ``inputs``.
            return_tensors: Tensor type requested from the tokenizer; defaults to ``"pt"``.
            tokenizer_kwargs: Extra tokenizer keyword arguments, normalized by ``_tokenizer_kwargs`` and applied to
                both tokenizer calls so the two sequences are encoded the same way.
            **preprocess_parameters: Accepted for signature compatibility and ignored.

        Returns:
            The tokenized reference inputs, the raw ``reference_sequence`` / ``alternative_sequence`` strings, and the
            alternative ``input_ids`` / ``attention_mask`` under ``alternative_``-prefixed keys.
        """
        if return_tensors is None:
            return_tensors = "pt"
        reference, alternative = _resolve_variant_inputs(inputs, alternative)
        tokenizer_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        reference_inputs = self.tokenizer(reference, return_tensors=return_tensors, **tokenizer_kwargs)
        alternative_inputs = self.tokenizer(alternative, return_tensors=return_tensors, **tokenizer_kwargs)
        return {
            **reference_inputs,
            "reference_sequence": reference,
            "alternative_sequence": alternative,
            "alternative_input_ids": alternative_inputs["input_ids"],
            # The tokenizer may not produce an attention mask; ``None`` is handled in ``_forward``.
            "alternative_attention_mask": alternative_inputs.get("attention_mask"),
        }

    def _forward(self, model_inputs):
        """
        Run the model on the reference and alternative inputs.

        MMSplice and MTSplice are called once with both sequences through their ``alternative_*`` keyword
        arguments, and their outputs are returned directly. Every other model is run twice, once per sequence, and
        both outputs are returned under ``reference_outputs`` / ``alternative_outputs``.

        The raw sequence strings are always attached to the result so ``postprocess`` can report them.
        """
        # Strip the non-model entries added by ``preprocess`` so only reference model inputs remain.
        reference_sequence = model_inputs.pop("reference_sequence")
        alternative_sequence = model_inputs.pop("alternative_sequence")
        alternative_input_ids = model_inputs.pop("alternative_input_ids")
        alternative_attention_mask = model_inputs.pop("alternative_attention_mask", None)

        model_type = _model_type(self.model)
        if model_type in {"mmsplice", "mtsplice"}:
            # Native paired-input head: one forward pass yields the variant effect.
            outputs = self.model(
                **model_inputs,
                alternative_input_ids=alternative_input_ids,
                alternative_attention_mask=alternative_attention_mask,
            )
            outputs["reference_sequence"] = reference_sequence
            outputs["alternative_sequence"] = alternative_sequence
            return outputs

        reference_outputs = self.model(**model_inputs)
        alternative_inputs = {
            "input_ids": alternative_input_ids,
            "attention_mask": alternative_attention_mask,
        }
        # Drop a missing attention mask instead of passing ``attention_mask=None`` explicitly.
        alternative_outputs = self.model(
            **{key: value for key, value in alternative_inputs.items() if value is not None}
        )
        return {
            "reference_outputs": reference_outputs,
            "alternative_outputs": alternative_outputs,
            "reference_sequence": reference_sequence,
            "alternative_sequence": alternative_sequence,
        }

    def postprocess(self, model_outputs, top_k: int | None = None):
        """
        Convert model outputs into a variant-effect result.

        Args:
            model_outputs: The value returned by ``_forward``.
            top_k: Forwarded to ``_variant_effect_result``; falls back to ``self.top_k`` when ``None``.

        Returns:
            The result of ``_variant_effect_result``.
            - For separately-scored models, the reference and alternative scores are included alongside the delta.
            - For native paired-input models, the processed outputs are used as the delta directly.
        """
        if top_k is None:
            top_k = self.top_k
        reference_sequence = model_outputs["reference_sequence"]
        alternative_sequence = model_outputs["alternative_sequence"]
        config = getattr(self.model, "config", None)

        if "reference_outputs" in model_outputs:
            # Two independent forward passes: score each, then subtract reference from alternative.
            reference_scores, channels = _processed_scores(
                model_outputs["reference_outputs"],
                model=self.model,
                config=config,
                task=self.task,
            )
            alternative_scores, _ = _processed_scores(
                model_outputs["alternative_outputs"],
                model=self.model,
                config=config,
                task=self.task,
            )
            delta_scores = alternative_scores - reference_scores
            return _variant_effect_result(
                reference_sequence,
                alternative_sequence,
                delta_scores,
                channels,
                reference_scores=reference_scores,
                alternative_scores=alternative_scores,
                top_k=top_k,
            )

        # Native paired-input head: the processed outputs are already the delta.
        delta_scores, channels = _processed_scores(
            model_outputs,
            model=self.model,
            config=config,
            task=self.task,
        )
        return _variant_effect_result(reference_sequence, alternative_sequence, delta_scores, channels, top_k=top_k)

    def _sanitize_parameters(
        self,
        alternative: str | None = None,
        top_k: int | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Split call-time keyword arguments into preprocess, forward, and postprocess parameters.

        Raises:
            PipelineException: If ``top_k`` is given and is not a positive integer.
        """
        preprocess_params: dict[str, Any] = {}
        if alternative is not None:
            preprocess_params["alternative"] = alternative
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
        postprocess_params: dict[str, Any] = {}
        if top_k is not None:
            import operator

            # ``operator.index`` accepts ints and int-like objects (e.g. NumPy integers) but rejects floats and
            # strings. ``bool`` is an ``int`` subclass, so it is rejected explicitly.
            try:
                valid = not isinstance(top_k, bool) and operator.index(top_k) > 0
            except TypeError:
                valid = False
            if not valid:
                raise PipelineException(self.task, _model_prefix(self.model), "top_k must be a positive integer.")
            postprocess_params["top_k"] = operator.index(top_k)
        return preprocess_params, {}, postprocess_params

    def __init__(self, *args, **kwargs):
        """
        Initialize the pipeline.

        Raises:
            NotImplementedError: If the loaded model is not a PyTorch ``torch.nn.Module``.
        """
        super().__init__(*args, **kwargs)
        if not isinstance(self.model, torch.nn.Module):
            raise NotImplementedError("Only PyTorch is supported for splice variant-effect prediction.")

    def __call__(self, inputs, alternative: str | None = None, **kwargs):
        """
        Score a variant.

        ``alternative`` may be passed positionally. It is forwarded as a keyword argument so
        ``_sanitize_parameters`` routes it to ``preprocess``.
        """
        if alternative is not None:
            kwargs["alternative"] = alternative
        return super().__call__(inputs, **kwargs)