
DeepMEL

Convolutional and recurrent neural network for predicting melanoma-specific accessible chromatin regions and chromatin topics directly from DNA sequence.

Disclaimer

This is an UNOFFICIAL implementation of "Cross-species analysis of enhancer logic using deep learning" by Liesbeth Minnoye, Ibrahim Ihsan Taskiran, et al.

The OFFICIAL repository of DeepMEL is at aertslab/DeepMEL.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.

The team that released DeepMEL did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

DeepMEL is a hybrid convolutional/recurrent neural network trained to predict 24 melanoma chromatin topics (including the melanocytic-like MEL topic 4, the mesenchymal-like MES topic 7, and additional accessibility programs) directly from 500 bp DNA sequences. Each input sequence is processed by an encoder consisting of a 1D convolution, max pooling, a time-distributed dense projection, and a bidirectional LSTM, followed by a fully-connected layer. The same encoder, with shared weights, is applied independently to the forward DNA strand and to its reverse complement; a final 24-way decoder produces a sigmoid probability per topic in each branch, and the two branches’ probabilities are averaged to form the model’s prediction. Please refer to the Training Details section for more information on the training process.
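Because the same encoder runs on both strands, the only strand-specific step is building the reverse-complement input. For a one-hot matrix this amounts to reversing the position axis and, if the channels are ordered A, C, G, T, also reversing the channel axis (which swaps A with T and C with G). The sketch below is a pure-Python illustration under that channel-order assumption; the helper names are hypothetical, and the MultiMolecule embedding defaults to 5 input channels (`vocab_size=5`), whose extra channel this sketch does not model.

```python
ALPHABET = "ACGT"  # assumed channel order (hypothetical; not read from the library)
COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def one_hot(sequence: str) -> list[list[int]]:
    """Encode a DNA string as a (length x 4) one-hot matrix."""
    return [[int(base == letter) for letter in ALPHABET] for base in sequence.upper()]


def reverse_complement_one_hot(encoded: list[list[int]]) -> list[list[int]]:
    """Reverse the position axis and the channel axis.

    With channels ordered A, C, G, T, reversing the channels maps A<->T and C<->G,
    i.e. it complements every base.
    """
    return [row[::-1] for row in reversed(encoded)]


def reverse_complement(sequence: str) -> str:
    """String-level reverse complement, used here only to check the one-hot version."""
    return "".join(COMPLEMENT[base] for base in reversed(sequence.upper()))


forward = one_hot("AACGT")
reverse = reverse_complement_one_hot(forward)
assert reverse == one_hot(reverse_complement("AACGT"))  # reverse complement is "ACGTT"
```

Applying the transform twice returns the original matrix, which is what lets the two branches be treated symmetrically.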

Model Specification

| Conv Filters | Conv Kernel | BiLSTM Hidden | FC Hidden | Num Topics | Num Parameters (M) | FLOPs (M) | MACs (M) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 128 | 20 | 128 | 256 | 24 | 3.44 | 40.76 | 20.19 | 500 |
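The 3.44 M parameter figure can be reproduced by hand. The arithmetic below is a sketch that assumes valid-padding convolution (as stated for `pooled_length`), the config's default 5 input channels, and one bias vector per LSTM direction (Keras-style); these are assumptions about the layer internals, not values read from the implementation. The shared encoder weights are counted once, since both strands reuse them.

```python
# Layer sizes from the specification table and the DeepMelConfig defaults.
vocab_size, input_length = 5, 500
conv_channels, conv_kernel, pool_size = 128, 20, 10
td_channels, lstm_hidden, fc_dim, num_topics = 128, 128, 256, 24


def dense_params(fan_in: int, fan_out: int) -> int:
    """Weight matrix plus bias vector of a dense layer."""
    return fan_in * fan_out + fan_out


conv = vocab_size * conv_kernel * conv_channels + conv_channels
time_distributed = dense_params(conv_channels, td_channels)
# Assumption: each LSTM direction has 4 gates, each with input weights,
# recurrent weights and a single bias vector.
lstm_direction = 4 * (lstm_hidden * td_channels + lstm_hidden * lstm_hidden + lstm_hidden)
bilstm = 2 * lstm_direction
pooled_length = (input_length - conv_kernel + 1) // pool_size  # time steps after pooling
flattened = pooled_length * 2 * lstm_hidden  # input width of the FC layer
fc = dense_params(flattened, fc_dim)
decoder = dense_params(fc_dim, num_topics)

total = conv + time_distributed + bilstm + fc + decoder
print(f"{total:,} parameters ~ {total / 1e6:.2f} M")  # 3,444,760 parameters ~ 3.44 M
print(f"FC layer share: {fc / total:.0%}")  # 91%
```

Under these assumptions the count lands on the table's 3.44 M, and almost all of it sits in the fully-connected layer over the flattened 48-step, 256-feature BiLSTM output.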

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Chromatin Topic Prediction

You can use this model directly to predict the 24 melanoma chromatin-topic activities of a 500 bp DNA sequence:

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, DeepMelForSequencePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepmel")
>>> model = DeepMelForSequencePrediction.from_pretrained("multimolecule/deepmel")
>>> sequence = "ACGT" * 125
>>> output = model(**tokenizer(sequence, return_tensors="pt"))

>>> output.logits.shape
torch.Size([1, 24])
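The example above uses a synthetic sequence of exactly 500 bp. Genomic regions of other lengths have to be cut to that window first. A minimal sketch, using a hypothetical `center_window` helper that trims symmetrically around the region centre and rejects shorter inputs rather than guessing a padding scheme:

```python
def center_window(sequence: str, length: int = 500) -> str:
    """Return the central `length` bases of `sequence` (hypothetical helper).

    When the excess is odd, the extra base is dropped from the right end.
    """
    if len(sequence) < length:
        raise ValueError(f"sequence has {len(sequence)} bp, need at least {length} bp")
    start = (len(sequence) - length) // 2
    return sequence[start : start + length]


region = "A" * 10 + "ACGT" * 125 + "T" * 11  # 521 bp toy region
window = center_window(region)
```

The resulting string can then be tokenized exactly as above, e.g. `tokenizer(center_window(region_sequence), return_tensors="pt")`, where `region_sequence` stands for your own sequence.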

Interface

  • Input length: fixed 500 bp DNA window
  • Alphabet: ACGT (one-hot encoded); the reverse complement is computed internally
  • Output: 24 chromatin-topic logits (multi-label binary); postprocess returns the branch-averaged sigmoid probability per topic
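Since `postprocess` simply applies a sigmoid to the logits, the logits must encode the branch-averaged probability. Per the source comments, the decoder runs on each strand, the two sigmoid outputs are averaged, and the result is exposed in logit space. The pure-Python sketch below illustrates that round trip and why it differs from averaging logits; the clamping constant `eps` is an assumption of this sketch, not a documented library setting.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logit(p: float, eps: float = 1e-7) -> float:
    """Inverse sigmoid; clamping keeps saturated probabilities finite (assumed eps)."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def branch_averaged_logit(forward_logit: float, reverse_logit: float) -> float:
    """Average the two strands in probability space, then map back to a logit."""
    mean_probability = (sigmoid(forward_logit) + sigmoid(reverse_logit)) / 2
    return logit(mean_probability)


forward_logit, reverse_logit = 3.0, -1.0
combined = branch_averaged_logit(forward_logit, reverse_logit)
print(round(sigmoid(combined), 4))  # 0.6108: mean of the two branch probabilities
print(round(sigmoid((forward_logit + reverse_logit) / 2), 4))  # 0.7311: averaging logits instead
```

Because the sigmoid is non-linear, the two orderings disagree, which is why the branch averaging has to happen after the per-branch decoder.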

Training Details

DeepMEL was trained to predict cell-type-specific accessible chromatin topics derived from single-cell ATAC-seq of melanoma cell lines.

Training Data

DeepMEL was trained on accessible genomic intervals derived from melanoma single-cell ATAC-seq experiments and modeled as 24 chromatin topics (including the melanocytic-like MEL topic 4 and the mesenchymal-like MES topic 7). Each training example is a 500 bp genomic interval labelled with a binary vector indicating which topics are active. Chromosome 2 was held out for validation and testing.
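A sketch of how such examples could be laid out: each interval carries the set of active topics, which becomes a 24-dimensional multi-hot vector, and every interval on chromosome 2 goes to the held-out set. The `Region` record, the 1-based topic numbering, the `chr2` naming and the coordinates are all hypothetical. The card does not say how the held-out intervals are divided between validation and testing, so the sketch keeps them together.

```python
from dataclasses import dataclass

NUM_TOPICS = 24


@dataclass(frozen=True)
class Region:
    """A hypothetical labelled training interval (1-based topic numbers)."""

    chrom: str
    start: int
    end: int
    active_topics: tuple[int, ...]


def multi_hot(active_topics, num_topics: int = NUM_TOPICS) -> list[int]:
    """Binary vector with a 1 for every active topic (topic k -> index k - 1)."""
    labels = [0] * num_topics
    for topic in active_topics:
        if not 1 <= topic <= num_topics:
            raise ValueError(f"topic {topic} outside 1..{num_topics}")
        labels[topic - 1] = 1
    return labels


def split_by_chromosome(regions, held_out: str = "chr2"):
    """Put every region on the held-out chromosome into the held-out set."""
    train = [region for region in regions if region.chrom != held_out]
    held = [region for region in regions if region.chrom == held_out]
    return train, held


regions = [
    Region("chr1", 1_000, 1_500, (4,)),
    Region("chr2", 5_000, 5_500, (7,)),
    Region("chr3", 9_000, 9_500, (4, 7)),
]
train, held = split_by_chromosome(regions)
labels = [multi_hot(region.active_topics) for region in train]
```

Splitting by chromosome rather than by random interval keeps nearby, correlated regions from leaking between training and evaluation.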

Training Procedure

Pre-training

The model was trained to minimize a multi-label binary cross-entropy loss between the branch-averaged sigmoid probabilities and the observed topic-activity labels.

  • Optimizer: Adam
  • Loss: Multi-label binary cross-entropy
  • Regularization: Dropout (0.2 after pooling, 0.1 LSTM input and recurrent dropout, 0.2 after the BiLSTM, 0.4 before the prediction head)
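The loss can be written out explicitly. The sketch below computes the mean binary cross-entropy across topics for one example from branch-averaged probabilities; the `eps` clamp and the per-topic mean are assumptions of this sketch, since the exact reduction used upstream is not stated in this card.

```python
import math


def multilabel_bce(probabilities, targets, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy across topics for one example."""
    if len(probabilities) != len(targets):
        raise ValueError("probabilities and targets must have the same length")
    total = 0.0
    for p, y in zip(probabilities, targets):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(targets)


# Branch-averaged probabilities for three (of 24) topics, with their binary labels.
loss = multilabel_bce([0.9, 0.2, 0.5], [1, 0, 1])
```

If the prediction head applies a logits-based BCE to the logit of the averaged probability, it would in principle give the same value as this probability-based form.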

Citation

BibTeX
@article{minnoye2020deepmel,
  author    = {Minnoye, Liesbeth and Taskiran, Ibrahim Ihsan and Mauduit, David and Fazio, Maurizio and Van Aerschot, Linde and Hulselmans, Gert and Christiaens, Valerie and Makhzami, Samira and Seltenhammer, Monika and Karras, Panagiotis and Primot, Aline and Cadieu, Edouard and van Rooijen, Ellen and Marine, Jean-Christophe and Egidy, Giorgia and Ghanem, Ghanem-Elias and Zon, Leonard and Wouters, Jasper and Aerts, Stein},
  title     = {Cross-species analysis of enhancer logic using deep learning},
  journal   = {Genome Research},
  volume    = 30,
  number    = 12,
  pages     = {1815--1834},
  year      = 2020,
  publisher = {Cold Spring Harbor Laboratory Press},
  doi       = {10.1101/gr.260844.120}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the MultiMolecule GitHub issues for any questions or comments on the model card.

Please contact the authors of the DeepMEL paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.deepmel

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): alphabet to use for tokenization.
    • If None, the default DNA alphabet will be used.
    • If a string, it should be the name of a predefined alphabet: standard, iupac, streamline, or nucleobase.
    • If an alphabet or a list of characters, that specific alphabet will be used.
  • nmers (int, default: 1): Size of the k-mers to tokenize into.
  • codon (bool, default: False): Whether to tokenize into codons.
  • replace_U_with_T (bool, default: True): Whether to replace U with T.
  • do_upper_case (bool, default: True): Whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
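The token counts in these examples follow from the splitting rules in `_tokenize`: overlapping k-mers give `len(sequence) - nmers + 1` tokens, codons give `len(sequence) // 3`, and the `<cls>`/`<eos>` ids (1 and 2 in the examples) are added around them. A pure-Python mirror of just the splitting step, with the vocabulary lookup omitted (`split_tokens` is a hypothetical name):

```python
def split_tokens(text: str, nmers: int = 1, codon: bool = False) -> list[str]:
    """Mirror of the splitting rules in DnaTokenizer._tokenize (no vocabulary lookup)."""
    text = text.upper().replace("U", "T")  # defaults: do_upper_case=True, replace_U_with_T=True
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(
                f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
            )
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # Overlapping k-mers with stride 1: len(text) - nmers + 1 tokens.
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(split_tokens("tataaagta", nmers=3))  # 7 k-mers, matching the 7 ids between <cls> and <eos>
print(split_tokens("tataaagta", codon=True))  # ['TAT', 'AAA', 'GTA']
```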
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the default DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)

DeepMelConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a DeepMelModel. It is used to instantiate a DeepMEL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepMEL aertslab/DeepMEL architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default: 5): Vocabulary size of the DeepMEL model. Defines the number of feature channels in the one-hot encoded input fed to the first convolution.
  • input_length (int, default: 500): The fixed length (in base pairs) of the input DNA sequence.
  • conv_channels (int, default: 128): Number of output channels (filters) of the first convolution.
  • conv_kernel_size (int, default: 20): Convolution kernel size.
  • pool_size (int, default: 10): Max-pool window applied after the convolution. The convolution stride is 1 and the pool stride equals the pool size, so the effective downsampling factor is pool_size.
  • time_distributed_channels (int, default: 128): Hidden size of the time-distributed dense layer applied after pooling.
  • lstm_hidden_size (int, default: 128): Hidden size of each direction of the bidirectional LSTM.
  • fc_dim (int, default: 256): Hidden size of the fully-connected layer between the recurrent stack and the prediction head.
  • hidden_act (str, default: 'relu'): The non-linear activation function (function or string) in the encoder. If a string, "gelu", "relu", "silu" and "gelu_new" are supported.
  • conv_dropout (float, default: 0.2): The dropout probability after the convolutional max-pool block.
  • recurrent_dropout (float, default: 0.2): The dropout probability after the bidirectional LSTM.
  • fc_dropout (float, default: 0.4): The dropout probability after the fully-connected layer.
  • lstm_dropout (float, default: 0.1): The dropout probability applied to the LSTM input weights during training.
  • lstm_recurrent_dropout (float, default: 0.1): The dropout probability applied to the LSTM recurrent weights during training.
  • num_labels (int, default: 24): Number of multi-label binary topics. DeepMEL predicts 24 melanoma topics, including topic 4 (MEL) and topic 7 (MES).
  • head (HeadConfig | None, default: None): The configuration of the prediction head. Defaults to a multi-label binary classification head (problem_type="multilabel"), matching DeepMEL’s chromatin-topic prediction task.

Examples:

Python Console Session
>>> from multimolecule import DeepMelConfig, DeepMelModel
>>> # Initializing a DeepMEL multimolecule/deepmel style configuration
>>> configuration = DeepMelConfig()
>>> # Initializing a model (with random weights) from the multimolecule/deepmel style configuration
>>> model = DeepMelModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/deepmel/configuration_deepmel.py
Python
class DeepMelConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`DeepMelModel`][multimolecule.models.DeepMelModel]. It is used to instantiate a DeepMEL model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the DeepMEL [aertslab/DeepMEL](https://github.com/aertslab/DeepMEL) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the DeepMEL model. Defines the number of feature channels in the one-hot encoded input
            fed to the first convolution.
            Defaults to 5.
        input_length:
            The fixed length (in base pairs) of the input DNA sequence.
            Defaults to 500.
        conv_channels:
            Number of output channels (filters) of the first convolution.
            Defaults to 128.
        conv_kernel_size:
            Convolution kernel size.
            Defaults to 20.
        pool_size:
            Max-pool window applied after the convolution. The convolution stride is 1 and the pool stride matches the
            pool size, so the effective downsampling factor equals `pool_size`.
            Defaults to 10.
        time_distributed_channels:
            Hidden size of the time-distributed dense layer applied after pooling.
            Defaults to 128.
        lstm_hidden_size:
            Hidden size of each direction of the bidirectional LSTM.
            Defaults to 128.
        fc_dim:
            Hidden size of the fully-connected layer between the recurrent stack and the prediction head.
            Defaults to 256.
        hidden_act:
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"silu"` and `"gelu_new"` are supported.
        conv_dropout:
            The dropout probability after the convolutional max-pool block.
        recurrent_dropout:
            The dropout probability after the bidirectional LSTM.
        fc_dropout:
            The dropout probability after the fully-connected layer.
        lstm_dropout:
            The dropout probability applied to the LSTM input weights during training.
        lstm_recurrent_dropout:
            The dropout probability applied to the LSTM recurrent weights during training.
        num_labels:
            Number of multi-label binary topics. DeepMEL predicts 24 melanoma topics, including topic 4 (MEL) and topic 7 (MES).
            Defaults to 24.
        head:
            The configuration of the prediction head. Defaults to a multi-label binary classification head
            (`problem_type="multilabel"`), matching DeepMEL's chromatin-topic prediction task.

    Examples:
        >>> from multimolecule import DeepMelConfig, DeepMelModel
        >>> # Initializing a DeepMEL multimolecule/deepmel style configuration
        >>> configuration = DeepMelConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/deepmel style configuration
        >>> model = DeepMelModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "deepmel"

    def __init__(
        self,
        vocab_size: int = 5,
        input_length: int = 500,
        conv_channels: int = 128,
        conv_kernel_size: int = 20,
        pool_size: int = 10,
        time_distributed_channels: int = 128,
        lstm_hidden_size: int = 128,
        fc_dim: int = 256,
        hidden_act: str = "relu",
        conv_dropout: float = 0.2,
        recurrent_dropout: float = 0.2,
        fc_dropout: float = 0.4,
        lstm_dropout: float = 0.1,
        lstm_recurrent_dropout: float = 0.1,
        num_labels: int = 24,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        if input_length <= 0:
            raise ValueError(f"input_length must be positive, got {input_length}.")
        if conv_kernel_size <= 0:
            raise ValueError(f"conv_kernel_size must be positive, got {conv_kernel_size}.")
        if pool_size <= 0:
            raise ValueError(f"pool_size must be positive, got {pool_size}.")
        if conv_channels <= 0:
            raise ValueError(f"conv_channels must be positive, got {conv_channels}.")
        if input_length < conv_kernel_size:
            raise ValueError(f"input_length ({input_length}) must be at least conv_kernel_size ({conv_kernel_size}).")
        if not 0.0 <= conv_dropout < 1.0:
            raise ValueError(f"conv_dropout must be in [0, 1), got {conv_dropout}.")
        if not 0.0 <= recurrent_dropout < 1.0:
            raise ValueError(f"recurrent_dropout must be in [0, 1), got {recurrent_dropout}.")
        if not 0.0 <= fc_dropout < 1.0:
            raise ValueError(f"fc_dropout must be in [0, 1), got {fc_dropout}.")
        if not 0.0 <= lstm_dropout < 1.0:
            raise ValueError(f"lstm_dropout must be in [0, 1), got {lstm_dropout}.")
        if not 0.0 <= lstm_recurrent_dropout < 1.0:
            raise ValueError(f"lstm_recurrent_dropout must be in [0, 1), got {lstm_recurrent_dropout}.")
        self.vocab_size = vocab_size
        self.input_length = input_length
        self.conv_channels = conv_channels
        self.conv_kernel_size = conv_kernel_size
        self.pool_size = pool_size
        self.time_distributed_channels = time_distributed_channels
        self.lstm_hidden_size = lstm_hidden_size
        self.fc_dim = fc_dim
        # `hidden_size` is the width of the per-branch fully-connected output. The prediction head decodes each branch
        # to the 24 topics separately and averages the probabilities; `pooler_output` is the branch-averaged FC output.
        self.hidden_size = fc_dim
        self.hidden_act = hidden_act
        self.conv_dropout = conv_dropout
        self.recurrent_dropout = recurrent_dropout
        self.fc_dropout = fc_dropout
        self.lstm_dropout = lstm_dropout
        self.lstm_recurrent_dropout = lstm_recurrent_dropout
        # DeepMEL performs multi-label binary classification of 24 chromatin topics. The MultiMolecule `problem_type`
        # convention lives on the head config, since the Transformers base config only accepts the HF `problem_type`
        # literals.
        if head is None:
            head = HeadConfig(problem_type="multilabel")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "multilabel"
        self.head = head

    @property
    def pooled_length(self) -> int:
        """Sequence length after the convolution (valid padding) and max-pooling step."""
        return (self.input_length - self.conv_kernel_size + 1) // self.pool_size

    @property
    def flattened_size(self) -> int:
        """Number of features produced by the per-branch flatten step (input width to the FC layer)."""
        return self.pooled_length * 2 * self.lstm_hidden_size

pooled_length property

Python
pooled_length: int

Sequence length after the convolution (valid padding) and max-pooling step.

flattened_size property

Python
flattened_size: int

Number of features produced by the per-branch flatten step (input width to the FC layer).

DeepMelForSequencePrediction

Bases: DeepMelPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import DeepMelConfig, DeepMelForSequencePrediction, DnaTokenizer
>>> config = DeepMelConfig()
>>> model = DeepMelForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepmel")
>>> input = tokenizer(["ACGT" * 125, "TGCA" * 125], return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (2, 24)).float())
>>> output["logits"].shape
torch.Size([2, 24])
Source code in multimolecule/models/deepmel/modeling_deepmel.py
Python
class DeepMelForSequencePrediction(DeepMelPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import DeepMelConfig, DeepMelForSequencePrediction, DnaTokenizer
        >>> config = DeepMelConfig()
        >>> model = DeepMelForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepmel")
        >>> input = tokenizer(["ACGT" * 125, "TGCA" * 125], return_tensors="pt")
        >>> output = model(**input, labels=torch.randint(2, (2, 24)).float())
        >>> output["logits"].shape
        torch.Size([2, 24])
    """

    def __init__(self, config: DeepMelConfig):
        super().__init__(config)
        self.model = DeepMelModel(config)
        # The upstream Keras model places the final 24-way Dense (`dense_3`) *inside each branch*, follows it with
        # a sigmoid, and then averages the two branches' probabilities. Because sigmoid is non-linear, that ordering
        # cannot be reproduced by the standard `SequencePredictionHead` (which averages its 256-dim input *before*
        # the decoder); we therefore own the decoder and expose the averaged probability in logit space.
        self.sequence_head = DeepMelSequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"topic_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"topic_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def postprocess(self, outputs: Any) -> Tensor:
        return torch.sigmoid(outputs["logits"])
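The `output_channels` property falls back to generic names whenever the config only carries the Transformers placeholder labels (`LABEL_0`, `LABEL_1`, and so on). The same rule as a standalone function, as a sketch (`topic_names` is hypothetical); note that the fallback names are zero-based, `topic_0` through `topic_23`.

```python
def topic_names(id2label, num_labels: int = 24) -> list[str]:
    """Mirror of DeepMelForSequencePrediction.output_channels (sketch)."""
    if id2label is not None:
        labels = [str(id2label.get(index, f"topic_{index}")) for index in range(num_labels)]
        # Keep the configured names unless every one is a Transformers placeholder.
        if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
            return labels
    return [f"topic_{index}" for index in range(num_labels)]


placeholder = {index: f"LABEL_{index}" for index in range(24)}
print(topic_names(placeholder)[:3])  # ['topic_0', 'topic_1', 'topic_2']
print(topic_names({5: "example_label"})[4:7])  # ['topic_4', 'example_label', 'topic_6']
```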

DeepMelModel

Bases: DeepMelPreTrainedModel

Examples:

Python Console Session
>>> from multimolecule import DeepMelConfig, DeepMelModel, DnaTokenizer
>>> config = DeepMelConfig()
>>> model = DeepMelModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepmel")
>>> input = tokenizer(["ACGT" * 125, "TGCA" * 125], return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([2, 256])
Source code in multimolecule/models/deepmel/modeling_deepmel.py
Python
class DeepMelModel(DeepMelPreTrainedModel):
    """
    Examples:
        >>> from multimolecule import DeepMelConfig, DeepMelModel, DnaTokenizer
        >>> config = DeepMelConfig()
        >>> model = DeepMelModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepmel")
        >>> input = tokenizer(["ACGT" * 125, "TGCA" * 125], return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([2, 256])
    """

    def __init__(self, config: DeepMelConfig):
        super().__init__(config)
        self.embeddings = DeepMelEmbedding(config)
        # Both the forward and reverse-complement branches share the same encoder weights, matching the upstream
        # Keras "siamese" model where `conv1d_1`, `time_distributed_1`, `bidirectional_1` and `dense_2` are each
        # applied to both `input_1` and `input_2` before the final averaging step.
        self.encoder = DeepMelEncoder(config)
        self.pooler = DeepMelPooler()

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> DeepMelModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        forward_embedding = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        reverse_embedding = self.embeddings.reverse_complement(forward_embedding)

        forward_features = self.encoder(forward_embedding)
        reverse_features = self.encoder(reverse_embedding)
        # The forward and reverse-complement 256-dim FC representations are exposed in the output dataclass so that
        # `DeepMelForSequencePrediction` can run the final 24-way decoder on each branch and average the
        # post-sigmoid topic probabilities, matching the upstream Keras model exactly. The `pooler_output` is the
        # average of the two branches, which is the natural sequence-level embedding for backbone use cases.
        pooled_output = self.pooler(forward_features, reverse_features)

        return DeepMelModelOutput(
            last_hidden_state=forward_features,
            pooler_output=pooled_output,
            forward_hidden_state=forward_features,
            reverse_hidden_state=reverse_features,
        )
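A consequence of sharing the encoder and averaging the two branches is that `pooler_output` should be unchanged when a sequence is replaced by its reverse complement (in eval mode, and assuming `reverse_complement` is an exact involution on the embedding), whereas `last_hidden_state`, which holds the forward branch only, is not. The toy sketch below demonstrates the principle with a deliberately strand-sensitive stand-in encoder; `toy_encoder` and `siamese_pool` are hypothetical and unrelated to the real layers.

```python
COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def reverse_complement(sequence: str) -> str:
    return "".join(COMPLEMENT[base] for base in reversed(sequence))


def toy_encoder(sequence: str) -> list[float]:
    """Hypothetical stand-in for the shared encoder; deliberately strand-sensitive."""
    weights = {"A": 1.0, "C": 2.0, "G": 3.0, "T": 4.0}
    positional = sum(weights[base] * (index + 1) for index, base in enumerate(sequence))
    return [positional, float(sequence.count("G"))]


def siamese_pool(sequence: str) -> list[float]:
    """Encode both strands with the same encoder and average, like the pooler."""
    forward = toy_encoder(sequence)
    reverse = toy_encoder(reverse_complement(sequence))
    return [(f + r) / 2 for f, r in zip(forward, reverse)]


sequence = "AACGT"
print(toy_encoder(sequence), toy_encoder(reverse_complement(sequence)))  # [41.0, 1.0] [50.0, 1.0]
print(siamese_pool(sequence), siamese_pool(reverse_complement(sequence)))  # [45.5, 1.0] [45.5, 1.0]
```

Swapping the input strand only swaps which branch sees which strand, and averaging is order-independent, so the pooled result is identical.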

DeepMelModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the DeepMEL model.

Parameters:

  • last_hidden_state (`torch.FloatTensor` of shape `(batch_size, fc_dim)`, default: None): Per-branch fully-connected representation of the forward DNA strand (i.e. before averaging with the reverse-complement branch). Useful for strand-specific interpretation.
  • pooler_output (`torch.FloatTensor` of shape `(batch_size, fc_dim)`, default: None): Branch-averaged sequence-level representation for backbone use cases. The topic head consumes the forward and reverse-complement branch representations directly.
  • forward_hidden_state (`torch.FloatTensor` of shape `(batch_size, fc_dim)`, default: None): Fully-connected representation of the forward DNA strand, before branch averaging.
  • reverse_hidden_state (`torch.FloatTensor` of shape `(batch_size, fc_dim)`, default: None): Fully-connected representation of the reverse-complement DNA strand, before branch averaging.
  • hidden_states (`tuple(torch.FloatTensor)`, *optional*, default: None): Always None; DeepMEL does not currently expose intermediate hidden states.
  • attentions (`tuple(torch.FloatTensor)`, *optional*, default: None): Always None; DeepMEL is a convolutional + recurrent model without explicit attention layers.

Source code in multimolecule/models/deepmel/modeling_deepmel.py
Python
@dataclass
class DeepMelModelOutput(ModelOutput):
    """
    Base class for outputs of the DeepMEL model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, fc_dim)`):
            Per-branch fully-connected representation of the forward DNA strand (i.e. before averaging with the
            reverse-complement branch). Useful for strand-specific interpretation.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, fc_dim)`):
            Branch-averaged sequence-level representation for backbone use cases. The topic head consumes the
            forward and reverse-complement branch representations directly.
        forward_hidden_state (`torch.FloatTensor` of shape `(batch_size, fc_dim)`):
            Fully-connected representation of the forward DNA strand, before branch averaging.
        reverse_hidden_state (`torch.FloatTensor` of shape `(batch_size, fc_dim)`):
            Fully-connected representation of the reverse-complement DNA strand, before branch averaging.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Always `None`; DeepMEL does not currently expose intermediate hidden states.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Always `None`; DeepMEL is a convolutional + recurrent model without explicit attention layers.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    forward_hidden_state: torch.FloatTensor | None = None
    reverse_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None

DeepMelPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/deepmel/modeling_deepmel.py
Python
class DeepMelPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DeepMelConfig
    base_model_prefix = "model"
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["DeepMelEncoder"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, nn.Conv1d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, DeepMelLstm):
            init.xavier_uniform_(module.weight_ih)
            init.orthogonal_(module.weight_hh)
            init.zeros_(module.bias)
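The `nn.Linear` branch of `_init_weights` matches PyTorch's default `Linear` initialization: with `a=sqrt(5)`, the Kaiming-uniform gain is sqrt(2 / (1 + 5)) = sqrt(1/3), so the weight bound collapses to 1/sqrt(fan_in), the same bound used for the bias. The arithmetic below re-derives these numbers from the formulas in the `torch.nn.init` documentation without calling PyTorch; the fan sizes assume the default layer shapes. As the source comment notes, the `init` wrappers skip tensors loaded from a checkpoint, so these values only matter for randomly initialized models.

```python
import math


def kaiming_uniform_bound(fan_in: int, a: float = math.sqrt(5)) -> float:
    """Bound b of U(-b, b) for kaiming_uniform_ with leaky_relu slope `a` and mode="fan_in"."""
    gain = math.sqrt(2.0 / (1.0 + a**2))
    return gain * math.sqrt(3.0 / fan_in)


def kaiming_normal_std(fan: int) -> float:
    """Standard deviation of kaiming_normal_ with nonlinearity="relu" (gain sqrt(2))."""
    return math.sqrt(2.0) / math.sqrt(fan)


# Default shapes (assumed): the FC layer sees flattened_size = 48 * 2 * 128 inputs,
# and a Conv1d weight of shape (128, in_channels, 20) has fan_out = 128 * 20.
fc_fan_in = 48 * 2 * 128
conv_fan_out = 128 * 20

fc_weight_bound = kaiming_uniform_bound(fc_fan_in)
fc_bias_bound = 1 / math.sqrt(fc_fan_in)
conv_std = kaiming_normal_std(conv_fan_out)
print(f"FC weight/bias bound: {fc_weight_bound:.5f} / {fc_bias_bound:.5f}")
print(f"Conv weight std: {conv_std:.5f}")
```

The very wide FC layer therefore starts with small weights, on the order of 0.009, which keeps its 12,288-term pre-activations in a reasonable range at initialization.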