class RegulatoryActivityPipeline(Pipeline):
    """
    Regulatory activity pipeline for DNA regulatory models.

    The pipeline takes raw nucleotide sequences and produces whole-sequence regulatory scores. Models that use
    auxiliary numeric features, such as Xpresso, can be given them through the optional `features=` argument.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise the base pipeline, then reject any model that is not a `torch.nn.Module`.

        Raises:
            NotImplementedError: If `self.model` is not a `torch.nn.Module`.
        """
        super().__init__(*args, **kwargs)
        is_torch_model = isinstance(self.model, torch.nn.Module)
        if not is_torch_model:
            raise NotImplementedError("Only PyTorch is supported for regulatory activity.")

    def _sanitize_parameters(
        self,
        features: Any | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
    ):
        """
        Collect the call-time options that `preprocess` should receive.

        Only options that are not `None` are forwarded. The two trailing empty dicts are presumably the forward
        and postprocess parameter sets expected by the base `Pipeline` -- confirm against the base class.

        Returns:
            A 3-tuple `(preprocess_params, {}, {})`.
        """
        candidates = (("features", features), ("tokenizer_kwargs", tokenizer_kwargs))
        preprocess_params: dict[str, Any] = {key: value for key, value in candidates if value is not None}
        return preprocess_params, {}, {}

    def preprocess(
        self,
        inputs: str | Mapping[str, Any],
        features: Any | None = None,
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize one input and attach auxiliary features when they are present.

        Args:
            inputs: A raw sequence string or a mapping; it is handed to `_resolve_sequence_inputs` together with
                `features` to obtain the sequence and the features to use.
            features: Optional auxiliary features, resolved alongside `inputs`.
            return_tensors: Tensor format forwarded to the tokenizer; `"pt"` is used when this is `None`.
            tokenizer_kwargs: Extra tokenizer options, normalised by `_tokenizer_kwargs` before use.
            **preprocess_parameters: Accepted but not used by this method.

        Returns:
            The tokenizer output, with a `"features"` entry added when the resolved features are not `None`.
        """
        tensor_format = "pt" if return_tensors is None else return_tensors
        sequence, resolved_features = _resolve_sequence_inputs(inputs, features)
        extra_tokenizer_args = _tokenizer_kwargs(tokenizer_kwargs)
        encoded = self.tokenizer(sequence, return_tensors=tensor_format, **extra_tokenizer_args)
        if resolved_features is None:
            return encoded
        encoded["features"] = _prepare_features(resolved_features)
        return encoded

    def _forward(self, model_inputs):
        """
        Run the model via `_call_model` and carry the token inputs along for `postprocess`.

        `input_ids` is required in `model_inputs`; `attention_mask` is copied when present and stored as `None`
        otherwise.
        """
        model_outputs = _call_model(self.model, model_inputs)
        # postprocess reads these two entries back to decode the original sequences.
        model_outputs["input_ids"] = model_inputs["input_ids"]
        model_outputs["attention_mask"] = model_inputs.get("attention_mask")
        return model_outputs

    def postprocess(self, model_outputs):
        """
        Turn model outputs into one sequence-level prediction result per decoded sequence.

        Returns:
            The single result when exactly one sequence was decoded, otherwise the list of results.
        """
        token_ids = model_outputs["input_ids"]
        mask = model_outputs.get("attention_mask")
        decoded = _decode_sequences(self.tokenizer, token_ids, mask)
        model_config = getattr(self.model, "config", None)
        scores, channels = _processed_scores(
            model_outputs, model=self.model, config=model_config, sequence_level=True, task=self.task
        )
        per_sample_scores = _sample_tensors(scores, len(decoded), token_level=False)
        results = []
        for sequence, sample_scores in zip(decoded, per_sample_scores):
            results.append(_sequence_prediction_result(sequence, sample_scores, channels=channels))
        # A lone prediction is unwrapped rather than returned as a one-element list.
        return results[0] if len(results) == 1 else results