Skip to content

Regulatory Variant Effect

RegulatoryVariantEffectPipeline

Bases: Pipeline

Regulatory variant-effect pipeline.

The pipeline scores a reference DNA sequence and an alternative DNA sequence and returns the alternative-minus-reference delta. Models with a native paired-input variant-effect head are called once with both sequences. All other models are scored on the reference and the alternative sequence separately, and the delta is computed from the two sets of scores.

Source code in multimolecule/pipelines/regulatory.py
Python
class RegulatoryVariantEffectPipeline(Pipeline):
    """
    Regulatory variant-effect pipeline.

    The pipeline scores a reference DNA sequence and an alternative DNA sequence and returns the
    alternative-minus-reference delta. Models with a native paired-input variant-effect head (detected by their call
    signature accepting ``alternative_input_ids``) are called once with both sequences. Other models are scored on
    reference and alternative separately, and the delta is computed in ``postprocess``.
    """

    # Default ``top_k`` forwarded to ``_variant_effect_result`` when the caller does not supply one.
    top_k: int | None = 20

    def preprocess(
        self,
        inputs: str | Mapping[str, Any],
        alternative: str | None = None,
        features: Any | None = None,
        alternative_features: Any | None = None,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize the reference and alternative sequences of one variant.

        Args:
            inputs: The reference sequence, or a mapping resolved by ``_resolve_variant_inputs``. Presumably the
                mapping carries the sequences and features; confirm against that helper.
            alternative: The alternative sequence, when it is not carried inside ``inputs``.
            features: Optional auxiliary features for the reference sequence.
            alternative_features: Optional auxiliary features for the alternative sequence.
            return_tensors: Tensor type requested from the tokenizer; defaults to ``"pt"``.
            tokenizer_kwargs: Extra tokenizer arguments, passed through ``_tokenizer_kwargs`` and applied identically
                to both tokenizer calls.
            **preprocess_parameters: Accepted for compatibility; ignored.

        Returns:
            The reference tokenizer outputs, extended with both raw sequences, the alternative ``input_ids`` and
            ``attention_mask`` (``None`` if the tokenizer produced none), and any prepared features.
        """
        if return_tensors is None:
            return_tensors = "pt"
        reference, alternative, features, alternative_features = _resolve_variant_inputs(
            inputs, alternative, features, alternative_features
        )
        tokenizer_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        reference_inputs = self.tokenizer(reference, return_tensors=return_tensors, **tokenizer_kwargs)
        alternative_inputs = self.tokenizer(alternative, return_tensors=return_tensors, **tokenizer_kwargs)
        output = {
            **reference_inputs,
            "reference_sequence": reference,
            "alternative_sequence": alternative,
            "alternative_input_ids": alternative_inputs["input_ids"],
            "alternative_attention_mask": alternative_inputs.get("attention_mask"),
        }
        if features is not None:
            output["features"] = _prepare_features(features)
        if alternative_features is not None:
            output["alternative_features"] = _prepare_features(alternative_features)
        return output

    def _forward(self, model_inputs):
        """
        Run the model on the reference/alternative pair.

        Models whose signature accepts ``alternative_input_ids`` get a single paired call. Otherwise the model is
        called once on the reference and once on the alternative, and both outputs are returned for
        ``postprocess``. ``model_inputs`` is consumed in place (bookkeeping keys are popped).
        """
        reference_sequence = model_inputs.pop("reference_sequence")
        alternative_sequence = model_inputs.pop("alternative_sequence")
        alternative_input_ids = model_inputs.pop("alternative_input_ids")
        alternative_attention_mask = model_inputs.pop("alternative_attention_mask", None)
        reference_features = model_inputs.pop("features", None)
        # Without explicit alternative features, score the alternative under the same features as the reference.
        alternative_features = model_inputs.pop("alternative_features", reference_features)

        if _model_accepts_argument(self.model, "alternative_input_ids"):
            paired_inputs = {
                **model_inputs,
                "alternative_input_ids": alternative_input_ids,
                "alternative_attention_mask": alternative_attention_mask,
            }
            if reference_features is not None:
                paired_inputs["features"] = reference_features
            # Only forward alternative features to models that declare the argument.
            if alternative_features is not None and _model_accepts_argument(self.model, "alternative_features"):
                paired_inputs["alternative_features"] = alternative_features
            outputs = _call_model(self.model, paired_inputs)
            outputs["reference_sequence"] = reference_sequence
            outputs["alternative_sequence"] = alternative_sequence
            return outputs

        # Separate scoring: like the paired path, pass ``features`` only when there actually are features.
        reference_call_inputs = dict(model_inputs)
        if reference_features is not None:
            reference_call_inputs["features"] = reference_features
        alternative_call_inputs = {
            "input_ids": alternative_input_ids,
            "attention_mask": alternative_attention_mask,
        }
        if alternative_features is not None:
            alternative_call_inputs["features"] = alternative_features

        reference_outputs = _call_model(self.model, reference_call_inputs)
        alternative_outputs = _call_model(self.model, alternative_call_inputs)
        return {
            "reference_outputs": reference_outputs,
            "alternative_outputs": alternative_outputs,
            "reference_sequence": reference_sequence,
            "alternative_sequence": alternative_sequence,
        }

    def postprocess(self, model_outputs, top_k: int | None = None):
        """
        Convert raw model outputs into a variant-effect result.

        For separately scored models the delta is ``alternative_scores - reference_scores``, and both per-sequence
        scores are included in the result. For paired models the processed model output is used as the delta
        directly.

        Args:
            model_outputs: The dictionary returned by ``_forward``.
            top_k: Per-call override of ``self.top_k``.
        """
        if top_k is None:
            top_k = self.top_k
        reference_sequence = model_outputs["reference_sequence"]
        alternative_sequence = model_outputs["alternative_sequence"]
        config = getattr(self.model, "config", None)

        def score(outputs):
            # Shared scoring settings for every output processed by this pipeline.
            return _processed_scores(
                outputs,
                model=self.model,
                config=config,
                sequence_level=True,
                task=self.task,
            )

        if "reference_outputs" in model_outputs:
            reference_scores, channels = score(model_outputs["reference_outputs"])
            alternative_scores, _ = score(model_outputs["alternative_outputs"])
            delta_scores = alternative_scores - reference_scores
            return _variant_effect_result(
                reference_sequence,
                alternative_sequence,
                delta_scores,
                channels,
                reference_scores=reference_scores,
                alternative_scores=alternative_scores,
                axis_name=_variant_axis_name(delta_scores, config),
                top_k=top_k,
            )

        delta_scores, channels = score(model_outputs)
        return _variant_effect_result(
            reference_sequence,
            alternative_sequence,
            delta_scores,
            channels,
            axis_name=_variant_axis_name(delta_scores, config),
            top_k=top_k,
        )

    def _sanitize_parameters(
        self,
        alternative: str | None = None,
        features: Any | None = None,
        alternative_features: Any | None = None,
        top_k: int | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Split call-time keyword arguments into preprocess/forward/postprocess parameter dicts.

        Raises:
            PipelineException: If ``top_k`` is given but is not a positive integer.
        """
        preprocess_params: dict[str, Any] = {}
        if alternative is not None:
            preprocess_params["alternative"] = alternative
        if features is not None:
            preprocess_params["features"] = features
        if alternative_features is not None:
            preprocess_params["alternative_features"] = alternative_features
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
        postprocess_params: dict[str, Any] = {}
        if top_k is not None:
            import numbers

            # Check the type before comparing: floats like 2.5 must be rejected too, and ``"5" <= 0`` would raise a
            # bare TypeError. ``numbers.Integral`` still admits NumPy integer scalars.
            if not isinstance(top_k, numbers.Integral) or top_k <= 0:
                raise PipelineException(self.task, _model_prefix(self.model), "top_k must be a positive integer.")
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def __init__(self, *args, **kwargs):
        """
        Initialise the base pipeline, then require a PyTorch model.

        Raises:
            NotImplementedError: If the loaded model is not a ``torch.nn.Module``.
        """
        super().__init__(*args, **kwargs)
        if not isinstance(self.model, torch.nn.Module):
            raise NotImplementedError("Only PyTorch is supported for regulatory variant-effect prediction.")

    def __call__(self, inputs, alternative: str | None = None, **kwargs):
        """Score a variant; ``alternative`` may be given positionally after the reference ``inputs``."""
        if alternative is not None:
            kwargs["alternative"] = alternative
        return super().__call__(inputs, **kwargs)