class RegulatoryVariantEffectPipeline(Pipeline):
    """
    Regulatory variant-effect pipeline.

    The pipeline scores a reference DNA sequence and an alternative DNA sequence and returns the
    alternative-minus-reference delta. Models with a native paired-input variant-effect head are called once; other
    models are scored on reference and alternative separately.
    """

    # Used by `postprocess` when no `top_k` is supplied for a call.
    top_k: int | None = 20

    def preprocess(
        self,
        inputs: str | Mapping[str, Any],
        alternative: str | None = None,
        features: Any | None = None,
        alternative_features: Any | None = None,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize the reference and alternative sequences.

        Args:
            inputs: The reference sequence, or a mapping that `_resolve_variant_inputs` unpacks into the
                reference, alternative and feature values.
            alternative: The alternative sequence when it is not supplied through `inputs`.
            features: Optional reference features, normalized with `_prepare_features`.
            alternative_features: Optional alternative features, normalized with `_prepare_features`.
            return_tensors: Tensor type requested from the tokenizer; defaults to `"pt"`.
            tokenizer_kwargs: Extra tokenizer keyword arguments, normalized with `_tokenizer_kwargs`.
            **preprocess_parameters: Accepted and ignored.

        Returns:
            The reference tokenizer outputs plus `reference_sequence`, `alternative_sequence` and
            `alternative_input_ids`, and, only when available, `alternative_attention_mask`, `features` and
            `alternative_features`.
        """
        if return_tensors is None:
            return_tensors = "pt"
        reference, alternative, features, alternative_features = _resolve_variant_inputs(
            inputs, alternative, features, alternative_features
        )
        tokenizer_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        reference_inputs = self.tokenizer(reference, return_tensors=return_tensors, **tokenizer_kwargs)
        alternative_inputs = self.tokenizer(alternative, return_tensors=return_tensors, **tokenizer_kwargs)
        output = {
            **reference_inputs,
            "reference_sequence": reference,
            "alternative_sequence": alternative,
            "alternative_input_ids": alternative_inputs["input_ids"],
        }
        # Mirror the reference side: only carry an attention mask the tokenizer actually produced instead of a
        # `None` placeholder that would later be handed to the model.
        alternative_attention_mask = alternative_inputs.get("attention_mask")
        if alternative_attention_mask is not None:
            output["alternative_attention_mask"] = alternative_attention_mask
        if features is not None:
            output["features"] = _prepare_features(features)
        if alternative_features is not None:
            output["alternative_features"] = _prepare_features(alternative_features)
        return output

    def _forward(self, model_inputs: dict[str, Any]):
        """
        Run the model on the tokenized reference/alternative pair.

        If the model accepts an `alternative_input_ids` argument it is called once with both sequences and its
        outputs are returned with the raw sequences attached. Otherwise the model is called separately on the
        reference and on the alternative, and both outputs are returned for `postprocess` to subtract.
        """
        reference_sequence = model_inputs.pop("reference_sequence")
        alternative_sequence = model_inputs.pop("alternative_sequence")
        alternative_input_ids = model_inputs.pop("alternative_input_ids")
        alternative_attention_mask = model_inputs.pop("alternative_attention_mask", None)
        reference_features = model_inputs.pop("features", None)
        explicit_alternative_features = model_inputs.pop("alternative_features", None)
        # Without dedicated alternative features, the alternative is scored with the reference features.
        alternative_features = (
            explicit_alternative_features if explicit_alternative_features is not None else reference_features
        )

        if _model_accepts_argument(self.model, "alternative_input_ids"):
            paired_inputs = {**model_inputs, "alternative_input_ids": alternative_input_ids}
            if alternative_attention_mask is not None:
                paired_inputs["alternative_attention_mask"] = alternative_attention_mask
            if reference_features is not None:
                paired_inputs["features"] = reference_features
            if alternative_features is not None:
                if _model_accepts_argument(self.model, "alternative_features"):
                    paired_inputs["alternative_features"] = alternative_features
                elif explicit_alternative_features is not None:
                    # Caller-supplied alternative features cannot reach this model; say so rather than
                    # silently returning a delta that ignores them.
                    import warnings

                    warnings.warn(
                        "`alternative_features` were provided but the model does not accept an "
                        "`alternative_features` argument; they will be ignored.",
                        stacklevel=2,
                    )
            outputs = _call_model(self.model, paired_inputs)
            outputs["reference_sequence"] = reference_sequence
            outputs["alternative_sequence"] = alternative_sequence
            return outputs

        # Separate scoring: only pass optional entries when they exist, consistent with the paired path above.
        # NOTE(review): the alternative call carries only input_ids/attention_mask/features; any other tokenizer
        # outputs present in `model_inputs` (e.g. token_type_ids) are given to the reference call only.
        reference_call_inputs = dict(model_inputs)
        if reference_features is not None:
            reference_call_inputs["features"] = reference_features
        alternative_call_inputs = {"input_ids": alternative_input_ids}
        if alternative_attention_mask is not None:
            alternative_call_inputs["attention_mask"] = alternative_attention_mask
        if alternative_features is not None:
            alternative_call_inputs["features"] = alternative_features

        reference_outputs = _call_model(self.model, reference_call_inputs)
        alternative_outputs = _call_model(self.model, alternative_call_inputs)
        return {
            "reference_outputs": reference_outputs,
            "alternative_outputs": alternative_outputs,
            "reference_sequence": reference_sequence,
            "alternative_sequence": alternative_sequence,
        }

    def postprocess(self, model_outputs, top_k: int | None = None):
        """
        Turn model outputs into a variant-effect result.

        For separately scored pairs, the delta is `alternative_scores - reference_scores`, and both per-sequence
        scores are also included in the result. For paired-head models, the processed model outputs are used
        directly as the delta.

        Args:
            model_outputs: The dictionary returned by `_forward`.
            top_k: Forwarded to `_variant_effect_result`; falls back to `self.top_k` when `None`.
        """
        if top_k is None:
            top_k = self.top_k
        config = getattr(self.model, "config", None)

        def score(outputs):
            # Single place for the shared `_processed_scores` arguments.
            return _processed_scores(
                outputs,
                model=self.model,
                config=config,
                sequence_level=True,
                task=self.task,
            )

        reference_sequence = model_outputs["reference_sequence"]
        alternative_sequence = model_outputs["alternative_sequence"]
        if "reference_outputs" in model_outputs:
            reference_scores, channels = score(model_outputs["reference_outputs"])
            alternative_scores, _ = score(model_outputs["alternative_outputs"])
            delta_scores = alternative_scores - reference_scores
            per_sequence_scores = {
                "reference_scores": reference_scores,
                "alternative_scores": alternative_scores,
            }
        else:
            delta_scores, channels = score(model_outputs)
            per_sequence_scores = {}

        return _variant_effect_result(
            reference_sequence,
            alternative_sequence,
            delta_scores,
            channels,
            **per_sequence_scores,
            axis_name=_variant_axis_name(delta_scores, config),
            top_k=top_k,
        )

    def _sanitize_parameters(
        self,
        alternative: str | None = None,
        features: Any | None = None,
        alternative_features: Any | None = None,
        top_k: int | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Split call-time keyword arguments into (preprocess, forward, postprocess) parameter dicts.

        Only arguments that are not `None` are forwarded. Raises `PipelineException` when `top_k` is not
        positive.
        """
        preprocess_params: dict[str, Any] = {}
        if alternative is not None:
            preprocess_params["alternative"] = alternative
        if features is not None:
            preprocess_params["features"] = features
        if alternative_features is not None:
            preprocess_params["alternative_features"] = alternative_features
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
        postprocess_params: dict[str, Any] = {}
        if top_k is not None:
            if top_k <= 0:
                raise PipelineException(self.task, _model_prefix(self.model), "top_k must be a positive integer.")
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def __init__(self, *args, **kwargs):
        """Initialize the base pipeline, then reject models that are not `torch.nn.Module` instances."""
        super().__init__(*args, **kwargs)
        if not isinstance(self.model, torch.nn.Module):
            raise NotImplementedError("Only PyTorch is supported for regulatory variant-effect prediction.")

    def __call__(self, inputs, alternative: str | None = None, **kwargs):
        """Score `inputs` against `alternative`; `alternative` is forwarded as a keyword only when given."""
        if alternative is not None:
            kwargs["alternative"] = alternative
        return super().__call__(inputs, **kwargs)