class SpliceSitePipeline(Pipeline):
    """
    Splice-site pipeline for RNA splicing models.

    The pipeline accepts raw nucleotide sequences and returns biological position-level scores. It supports classical
    fixed-window splice-site scorers such as MaxEntScan and per-nucleotide models such as OpenSpliceAI, Pangolin, and
    SpTransformer.

    Args:
        threshold: Optional instance-wide default for the ``threshold`` post-processing option. Validated with
            ``_check_probability``.
        output_scores: Optional instance-wide default for the ``output_scores`` post-processing option. Must be a
            ``bool``.

    Raises:
        NotImplementedError: If the loaded model is not a ``torch.nn.Module``.
        PipelineException: If ``output_scores`` is given and is not a boolean.
    """

    # Class-level defaults; used by `postprocess` whenever the matching call-time argument is None.
    threshold: float = 0.5
    output_scores: bool = True
    top_k: int | None = None

    def __init__(self, *args, threshold: float | None = None, output_scores: bool | None = None, **kwargs):
        # The base initializer runs first because the validation below reads `self.model` and `self.task`.
        super().__init__(*args, **kwargs)
        if not isinstance(self.model, torch.nn.Module):
            raise NotImplementedError("Only PyTorch is supported for splice-site prediction.")
        if threshold is not None:
            _check_probability("threshold", threshold, self.task, _model_prefix(self.model))
            self.threshold = threshold
        if output_scores is not None:
            self._check_output_scores(output_scores)
            self.output_scores = output_scores

    def _check_output_scores(self, output_scores: Any) -> None:
        """Raise `PipelineException` unless ``output_scores`` is a real ``bool``."""
        if not isinstance(output_scores, bool):
            raise PipelineException(
                self.task,
                _model_prefix(self.model),
                f"output_scores must be a boolean, got {type(output_scores)}.",
            )

    def _check_top_k(self, top_k: Any) -> None:
        """Raise `PipelineException` unless ``top_k`` is a strictly positive integer."""
        # Function-scope import keeps this block self-contained (the file's import section is not visible here).
        from numbers import Integral

        # `Integral` also accepts NumPy integer scalars. `bool` is excluded explicitly because it subclasses `int`.
        # Without the type check, floats slipped through and strings surfaced as a raw TypeError from `<=`.
        if isinstance(top_k, bool) or not isinstance(top_k, Integral) or top_k <= 0:
            raise PipelineException(self.task, _model_prefix(self.model), "top_k must be a positive integer.")

    def _sanitize_parameters(
        self,
        threshold: float | None = None,
        output_scores: bool | None = None,
        top_k: int | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Validate user options and route them to the pipeline stage that consumes them.

        Options left as ``None`` are omitted from the returned dictionaries, so `postprocess` falls back to the
        instance defaults for them.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)`` tuple. ``forward_params`` is always empty.

        Raises:
            PipelineException: If ``output_scores`` is not a boolean or ``top_k`` is not a positive integer.
        """
        preprocess_params: dict[str, Any] = {}
        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs

        postprocess_params: dict[str, Any] = {}
        if threshold is not None:
            _check_probability("threshold", threshold, self.task, _model_prefix(self.model))
            postprocess_params["threshold"] = threshold
        if output_scores is not None:
            self._check_output_scores(output_scores)
            postprocess_params["output_scores"] = output_scores
        if top_k is not None:
            self._check_top_k(top_k)
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def preprocess(
        self,
        inputs: str,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize ``inputs`` for the model.

        Args:
            inputs: Sequence text handed to the tokenizer.
            return_tensors: Tensor format requested from the tokenizer; defaults to ``"pt"``.
            tokenizer_kwargs: Extra tokenizer options, normalized through ``_tokenizer_kwargs`` first.
            **preprocess_parameters: Accepted for interface compatibility; not used here.

        Returns:
            The tokenizer output.
        """
        if return_tensors is None:
            return_tensors = "pt"
        tokenizer_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)

    def _forward(self, model_inputs):
        """Run the model and attach the inputs that `postprocess` needs to decode the sequences."""
        outputs = self.model(**model_inputs)
        # Carry the token ids and the (possibly missing) attention mask along with the model outputs.
        outputs["input_ids"] = model_inputs["input_ids"]
        outputs["attention_mask"] = model_inputs.get("attention_mask")
        return outputs

    def postprocess(
        self,
        model_outputs,
        threshold: float | None = None,
        output_scores: bool | None = None,
        top_k: int | None = None,
    ):
        """
        Convert model outputs into one splice-site result per input sequence.

        Args:
            model_outputs: Output of `_forward`; must contain ``"input_ids"`` and may contain ``"attention_mask"``.
            threshold: Overrides ``self.threshold`` when not ``None``.
            output_scores: Overrides ``self.output_scores`` when not ``None``.
            top_k: Overrides ``self.top_k`` when not ``None``.

        Returns:
            The single result when exactly one sequence was decoded, otherwise a list of results (each built by
            ``_splice_site_result``).
        """
        if threshold is None:
            threshold = self.threshold
        if output_scores is None:
            output_scores = self.output_scores
        if top_k is None:
            top_k = self.top_k

        sequences = _decode_sequences(
            self.tokenizer,
            model_outputs["input_ids"],
            model_outputs.get("attention_mask"),
        )
        config = getattr(self.model, "config", None)
        scores, channels = _processed_scores(model_outputs, model=self.model, config=config, task=self.task)

        results = [
            _splice_site_result(
                sequence,
                sample_scores,
                channels=channels,
                threshold=threshold,
                output_scores=output_scores,
                top_k=top_k,
            )
            for sequence, sample_scores in zip(sequences, _sample_tensors(scores, len(sequences)))
        ]
        # A lone result is unwrapped so single-sequence calls do not get a one-element list.
        if len(results) == 1:
            return results[0]
        return results