Skip to content

Regulatory Track

RegulatoryTrackPipeline

Bases: Pipeline

Binned regulatory track pipeline for DNA regulatory models.

The pipeline accepts raw nucleotide sequences and returns per-bin regulatory track scores, the kind of output produced by models such as Basenji and Enformer.

Source code in multimolecule/pipelines/regulatory.py
Python
class RegulatoryTrackPipeline(Pipeline):
    """
    Pipeline that turns raw nucleotide sequences into binned regulatory track predictions.

    Each input is tokenized, run through a PyTorch regulatory model, and the model output is converted into per-bin
    regulatory track scores, the kind of output produced by models such as Basenji and Enformer.

    Constructing the pipeline with a model that is not a ``torch.nn.Module`` raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The check runs after the base initializer, since that is where ``self.model`` gets set.
        is_torch_model = isinstance(self.model, torch.nn.Module)
        if not is_torch_model:
            raise NotImplementedError("Only PyTorch is supported for regulatory track prediction.")

    def _sanitize_parameters(self, tokenizer_kwargs: dict[str, Any] | None = None):
        """
        Split call-time keyword arguments into per-stage parameter dicts.

        Args:
            tokenizer_kwargs: Optional extra keyword arguments to forward to ``preprocess``.

        Returns:
            A ``(preprocess, forward, postprocess)`` tuple of fresh dicts. Only the preprocess dict can be non-empty,
            and only when ``tokenizer_kwargs`` was given.
        """
        preprocess_params: dict[str, Any] = {} if tokenizer_kwargs is None else {"tokenizer_kwargs": tokenizer_kwargs}
        forward_params: dict[str, Any] = {}
        postprocess_params: dict[str, Any] = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(
        self,
        inputs: str | Mapping[str, Any],
        return_tensors: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        **preprocess_parameters,
    ) -> dict[str, GenericTensor]:
        """
        Tokenize one input into model-ready tensors.

        Args:
            inputs: A raw sequence string or a mapping. The sequence is extracted by ``_resolve_sequence_inputs``.
            return_tensors: Tensor format requested from the tokenizer. Defaults to ``"pt"`` when ``None``.
            tokenizer_kwargs: Optional extra tokenizer arguments, normalized by ``_tokenizer_kwargs``.
            **preprocess_parameters: Accepted for signature compatibility and ignored.

        Returns:
            Whatever ``self.tokenizer`` returns for the resolved sequence.
        """
        tensor_format = "pt" if return_tensors is None else return_tensors
        sequence, _ = _resolve_sequence_inputs(inputs, None)
        encode_kwargs = _tokenizer_kwargs(tokenizer_kwargs)
        return self.tokenizer(sequence, return_tensors=tensor_format, **encode_kwargs)

    def _forward(self, model_inputs):
        """
        Run the model and attach the tokenized inputs to its outputs.

        ``input_ids`` is required. ``attention_mask`` is optional and is stored as ``None`` when absent.
        ``postprocess`` reads both to decode the original sequences.
        """
        model_outputs = self.model(**model_inputs)
        model_outputs["input_ids"] = model_inputs["input_ids"]
        model_outputs["attention_mask"] = model_inputs.get("attention_mask")
        return model_outputs

    def postprocess(self, model_outputs):
        """
        Convert raw model outputs into per-sequence track prediction results.

        Returns:
            A single result when the batch holds exactly one sequence, otherwise a list of results. Each result is
            whatever ``_track_prediction_result`` builds for a sequence and its per-bin scores.
        """
        sequences = _decode_sequences(
            self.tokenizer, model_outputs["input_ids"], model_outputs.get("attention_mask")
        )
        config = getattr(self.model, "config", None)
        # sequence_level=False asks for per-position (binned) scores rather than one score per sequence.
        scores, channels = _processed_scores(
            model_outputs, model=self.model, config=config, sequence_level=False, task=self.task
        )
        per_sample_scores = _sample_tensors(scores, len(sequences), token_level=True)
        predictions = [
            _track_prediction_result(seq, seq_scores, channels=channels, config=config)
            for seq, seq_scores in zip(sequences, per_sample_scores)
        ]
        return predictions[0] if len(predictions) == 1 else predictions