class SpliceVariantEffectPipeline(Pipeline):
    """
    Splice variant-effect pipeline.

    The pipeline scores a reference sequence and an alternative sequence and returns the alternative-minus-reference
    delta. Models with a native paired-input variant-effect head, such as MMSplice and MTSplice, are called once with
    both sequences. Other models are scored on reference and alternative separately, and both sequences are fed the
    same set of tokenizer outputs so the delta reflects only the sequence change.
    """

    # Default number of entries requested from `_variant_effect_result` when `top_k` is not given at call time.
    top_k: int | None = 20

    def preprocess(
        self,
        inputs: str | Mapping[str, str],
        alternative: str | None = None,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize the reference and alternative sequences.

        Args:
            inputs: The reference sequence, or a structure that `_resolve_variant_inputs` turns into a
                (reference, alternative) pair.
            alternative: The alternative sequence when it is not carried by `inputs`.
            return_tensors: Tensor type requested from the tokenizer; defaults to ``"pt"``.
            tokenizer_kwargs: Extra tokenizer keyword arguments, normalised by `_tokenizer_kwargs`.
            **preprocess_parameters: Accepted for interface compatibility; not used.

        Returns:
            A flat dict with:
            - the reference tokenizer outputs under their own names;
            - every alternative tokenizer output under the same name prefixed with ``"alternative_"``
              (``alternative_attention_mask`` is always present and may be ``None``);
            - the raw ``reference_sequence`` and ``alternative_sequence`` strings.
        """
        if return_tensors is None:
            return_tensors = "pt"
        reference, alternative = _resolve_variant_inputs(inputs, alternative)
        tokenizer_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        reference_inputs = self.tokenizer(reference, return_tensors=return_tensors, **tokenizer_kwargs)
        alternative_inputs = self.tokenizer(alternative, return_tensors=return_tensors, **tokenizer_kwargs)
        return {
            **reference_inputs,
            # Forward *all* alternative tokenizer outputs (e.g. token_type_ids), not only input_ids/attention_mask,
            # so the separate-scoring path can call the model with matching inputs for both sequences.
            **{f"alternative_{key}": value for key, value in alternative_inputs.items()},
            # Explicit indexing keeps the original KeyError when the tokenizer produced no input_ids.
            "alternative_input_ids": alternative_inputs["input_ids"],
            # This key has always been present (possibly as None); keep that contract.
            "alternative_attention_mask": alternative_inputs.get("attention_mask"),
            # Written last so no "alternative_*" tokenizer key can shadow the raw sequences.
            "reference_sequence": reference,
            "alternative_sequence": alternative,
        }

    def _forward(self, model_inputs):
        """
        Run the model on the tokenized reference/alternative pair.

        For model types ``"mmsplice"`` and ``"mtsplice"`` (as reported by `_model_type`), the model is called once
        with the reference inputs plus ``alternative_input_ids`` / ``alternative_attention_mask``, and the raw
        sequences are attached to its outputs. For any other model, reference and alternative are scored in two
        separate calls that receive the same set of tokenizer outputs.

        Args:
            model_inputs: The dict produced by `preprocess`. It is consumed in place: sequence and alternative keys
                are popped, leaving only the reference model inputs.

        Returns:
            Either the native model outputs with ``reference_sequence`` / ``alternative_sequence`` added, or a dict
            with ``reference_outputs``, ``alternative_outputs``, ``reference_sequence`` and ``alternative_sequence``.
        """
        reference_sequence = model_inputs.pop("reference_sequence")
        alternative_sequence = model_inputs.pop("alternative_sequence")
        # The sequences are popped first so that every remaining "alternative_*" key is a tokenizer output for the
        # alternative sequence; whatever is left afterwards are the reference model inputs.
        prefix = "alternative_"
        alternative_keys = [key for key in model_inputs if key.startswith(prefix)]
        alternative_inputs = {key.removeprefix(prefix): model_inputs.pop(key) for key in alternative_keys}

        model_type = _model_type(self.model)
        if model_type in {"mmsplice", "mtsplice"}:
            # Native paired head: keep passing exactly the two alternative arguments it has always received.
            outputs = self.model(
                **model_inputs,
                alternative_input_ids=alternative_inputs["input_ids"],
                alternative_attention_mask=alternative_inputs.get("attention_mask"),
            )
            outputs["reference_sequence"] = reference_sequence
            outputs["alternative_sequence"] = alternative_sequence
            return outputs

        reference_outputs = self.model(**model_inputs)
        # Drop None entries (e.g. a missing attention mask) rather than passing them to the model explicitly.
        alternative_outputs = self.model(
            **{key: value for key, value in alternative_inputs.items() if value is not None}
        )
        return {
            "reference_outputs": reference_outputs,
            "alternative_outputs": alternative_outputs,
            "reference_sequence": reference_sequence,
            "alternative_sequence": alternative_sequence,
        }

    def postprocess(self, model_outputs, top_k: int | None = None):
        """
        Turn model outputs into a variant-effect result.

        If ``model_outputs`` contains ``reference_outputs`` (separate-scoring path), both score sets are computed
        with `_processed_scores` and the alternative-minus-reference difference is reported together with both
        score sets. Otherwise the native outputs are processed directly as delta scores.

        Args:
            model_outputs: The value returned by `_forward`.
            top_k: Forwarded to `_variant_effect_result`; falls back to ``self.top_k`` when ``None``.

        Returns:
            Whatever `_variant_effect_result` builds from the sequences, delta scores and channels.
        """
        if top_k is None:
            top_k = self.top_k
        reference_sequence = model_outputs["reference_sequence"]
        alternative_sequence = model_outputs["alternative_sequence"]
        config = getattr(self.model, "config", None)
        if "reference_outputs" in model_outputs:
            reference_scores, channels = _processed_scores(
                model_outputs["reference_outputs"],
                model=self.model,
                config=config,
                task=self.task,
            )
            # Channels are taken from the reference pass; the alternative pass uses the same model and config.
            alternative_scores, _ = _processed_scores(
                model_outputs["alternative_outputs"],
                model=self.model,
                config=config,
                task=self.task,
            )
            delta_scores = alternative_scores - reference_scores
            return _variant_effect_result(
                reference_sequence,
                alternative_sequence,
                delta_scores,
                channels,
                reference_scores=reference_scores,
                alternative_scores=alternative_scores,
                top_k=top_k,
            )
        delta_scores, channels = _processed_scores(
            model_outputs,
            model=self.model,
            config=config,
            task=self.task,
        )
        return _variant_effect_result(reference_sequence, alternative_sequence, delta_scores, channels, top_k=top_k)

    def _sanitize_parameters(
        self,
        alternative: str | None = None,
        top_k: int | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Route call-time keyword arguments to the pipeline stages.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)`` triple; forward params are always empty.

        Raises:
            PipelineException: If ``top_k`` is given and is not positive.
        """
        preprocess_params: dict[str, Any] = {}
        if alternative is not None:
            preprocess_params["alternative"] = alternative
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
        postprocess_params: dict[str, Any] = {}
        if top_k is not None:
            if top_k <= 0:
                raise PipelineException(self.task, _model_prefix(self.model), "top_k must be a positive integer.")
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def __init__(self, *args, **kwargs):
        """
        Initialise the pipeline.

        Raises:
            NotImplementedError: If the resolved model is not a ``torch.nn.Module``.
        """
        super().__init__(*args, **kwargs)
        if not isinstance(self.model, torch.nn.Module):
            raise NotImplementedError("Only PyTorch is supported for splice variant-effect prediction.")

    def __call__(self, inputs, alternative: str | None = None, **kwargs):
        """
        Score a variant.

        ``alternative`` may be given as a named argument; when present it is forwarded through ``kwargs`` so that
        `_sanitize_parameters` routes it to `preprocess`.
        """
        if alternative is not None:
            kwargs["alternative"] = alternative
        return super().__call__(inputs, **kwargs)