
DeepCpG-DNA

DNA-only convolutional neural network from DeepCpG for predicting per-cell single-cell DNA methylation states from a CpG-centered sequence window.

Disclaimer

This is an UNOFFICIAL implementation of DeepCpG: accurate prediction of single-cell DNA methylation states using deep learning by Christof Angermueller et al.

The OFFICIAL repository of DeepCpG is at cangermueller/deepcpg.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing DeepCpG-DNA did not write a model card for this model, so this model card has been written by the MultiMolecule team.

Model Details

DeepCpG-DNA is the DNA submodule of the DeepCpG joint model. It is a 1D convolutional neural network that predicts the per-cell methylation state of a CpG site from a fixed-length 1001 bp DNA window centered on the site. The model consumes a one-hot encoded sequence and applies valid-padded convolutional blocks (Conv1D + ReLU + MaxPool) followed by a dense bottleneck and one binary classification head per single cell in the training dataset. Please refer to the Training Details section for more information on the training process.
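The following pure-Python sketch makes "valid-padded" concrete. It runs one Conv1D + ReLU + MaxPool block on a toy one-hot input and shows the shape behaviour:

- A kernel of size k shortens the sequence by k - 1.
- A pool of size p divides the length by p, rounding down.

The sketch makes a few assumptions. It uses the A, C, G, T, N channel order of the streamline alphabet, and the weights are random, hypothetical values. It is an illustration, not the MultiMolecule implementation.

```python
import random

CHANNELS = "ACGTN"  # assumed one-hot channel order (streamline alphabet: A, C, G, T, N)


def one_hot(seq: str) -> list[list[float]]:
    """Encode a DNA string as a [length][channels] one-hot matrix."""
    return [[1.0 if base == ch else 0.0 for ch in CHANNELS] for base in seq.upper()]


def conv1d_valid(x, weights, bias):
    """Valid (unpadded) 1D convolution.

    x: [length][in_channels]; weights: [out_channels][kernel][in_channels]; bias: [out_channels].
    Returns [length - kernel + 1][out_channels].
    """
    kernel = len(weights[0])
    in_channels = len(x[0])
    return [
        [
            bias[o] + sum(w[k][c] * x[i + k][c] for k in range(kernel) for c in range(in_channels))
            for o, w in enumerate(weights)
        ]
        for i in range(len(x) - kernel + 1)
    ]


def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


def max_pool(x, pool):
    """Non-overlapping max pooling (stride == pool); a trailing partial window is dropped."""
    return [
        [max(x[i + j][c] for j in range(pool)) for c in range(len(x[0]))]
        for i in range(0, len(x) - pool + 1, pool)
    ]


random.seed(0)
out_channels, kernel, pool = 3, 5, 4  # toy sizes; CnnL2h128's first block uses 128 filters, kernel 11, pool 4
weights = [[[random.uniform(-1, 1) for _ in CHANNELS] for _ in range(kernel)] for _ in range(out_channels)]
bias = [0.0] * out_channels

x = one_hot("ACGTACGTCGATCGATNACG")       # 20 bp toy window
h = relu(conv1d_valid(x, weights, bias))  # 20 - 5 + 1 = 16 positions
y = max_pool(h, pool)                     # 16 // 4 = 4 positions
print(len(h), len(y))  # 16 4
```

With the real first block, the same arithmetic takes a 1001 bp window to 991 positions after the convolution and to 247 after pooling.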

The full DeepCpG model combines this DNA submodule with a recurrent CpG-context submodule and a joint head; this MultiMolecule release ports the DNA submodule only.

Variants

The DeepCpG-DNA module is trained per single-cell dataset, so each Hub checkpoint has a dataset-specific number of output cells.

Dataset | Architecture | Cells | Hub repository
Smallwood 2014 serum mESC | CnnL2h128 | 18 | deepcpgdna-smallwood2014-serum
Smallwood 2014 2i mESC | CnnL3h128 | 12 | deepcpgdna-smallwood2014-2i
Hou 2016 HCC | CnnL2h128 | 25 | deepcpgdna-hou2016-hcc
Hou 2016 HepG2 | CnnL3h128 | 6 | deepcpgdna-hou2016-hepg2
Hou 2016 mESC | CnnL2h128 | 6 | deepcpgdna-hou2016-mesc

Model Specification

Architecture | Num Conv Layers | Hidden Size | Num Cells | Num Parameters (M) | FLOPs (M) | MACs (M) | Max Num Tokens
CnnL2h128 | 2 | 128 | 18 | 4.11 | 70.63 | 35.06 | 1001
CnnL3h128 | 3 | 128 | 12 | 4.43 | 165.02 | 82.18 | 1001
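The parameter and MAC columns can be checked by hand. The sketch below assumes the following setup:

- CnnL2h128 uses Conv1D(128, 11), MaxPool(4), Conv1D(256, 3), MaxPool(2).
- CnnL3h128 adds Conv1D(512, 3) and MaxPool(2), which is an assumption about the upstream third block.
- The input has 5 channels.
- Both architectures end in a 128-unit dense bottleneck and a per-cell output layer.

It counts only convolution and dense multiply-accumulates and ignores bias, activation and pooling operations. It is therefore not meant to reproduce the FLOPs column.

```python
def count(conv_channels, kernel_sizes, pool_sizes, num_cells,
          in_channels=5, length=1001, bottleneck=128):
    """Return (parameters, MACs) of a DeepCpG-DNA-style CNN (weights + biases; conv/dense MACs only)."""
    params = macs = 0
    for out_channels, kernel, pool in zip(conv_channels, kernel_sizes, pool_sizes):
        length = length - kernel + 1                    # valid convolution
        params += in_channels * kernel * out_channels + out_channels
        macs += length * out_channels * in_channels * kernel
        length //= pool                                 # non-overlapping max pooling
        in_channels = out_channels
    flat = length * in_channels                         # Flatten
    params += flat * bottleneck + bottleneck            # dense bottleneck
    macs += flat * bottleneck
    params += bottleneck * num_cells + num_cells        # per-cell output layer
    macs += bottleneck * num_cells
    return params, macs


l2 = count([128, 256], [11, 3], [4, 2], num_cells=18)
l3 = count([128, 256, 512], [11, 3, 3], [4, 2, 2], num_cells=12)
print([round(v / 1e6, 2) for v in l2])  # [4.11, 35.06]
print([round(v / 1e6, 2) for v in l3])  # [4.43, 82.18]
```

Only the last layer depends on the number of cells. For example, the same CnnL2h128 stack with 25 cells (Hou 2016 HCC) has (25 - 18) × 129 = 903 more parameters than the 18-cell row above.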

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Single-Cell Methylation Prediction

You can use this model directly to predict the per-cell methylation state of a 1001 bp DNA window centered on a CpG site:

Python
>>> from multimolecule import DnaTokenizer, DeepCpgDnaForSequencePrediction

>>> model_id = "multimolecule/deepcpgdna-smallwood2014-serum"
>>> tokenizer = DnaTokenizer.from_pretrained(model_id)
>>> model = DeepCpgDnaForSequencePrediction.from_pretrained(model_id)
>>> input = tokenizer("ACGT" * 250 + "A", return_tensors="pt")
>>> output = model(**input)

>>> output.logits.shape
torch.Size([1, 18])

Each logit is a per-cell methylation score for one of the single cells in the chosen training dataset; apply a sigmoid to obtain methylation probabilities.
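Converting logits to probabilities is an element-wise sigmoid, the same operation as the `postprocess` method in the API reference. The default `cell_{index}` names mirror what `output_channels` returns when no custom `id2label` is set.

The sketch below uses made-up logits and a 0.5 threshold to call a cell methylated. Both are illustrative assumptions.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def methylation_calls(logits, threshold=0.5, names=None):
    """Map one row of per-cell logits to {cell_name: (probability, is_methylated)}."""
    names = names or [f"cell_{i}" for i in range(len(logits))]
    probs = [sigmoid(v) for v in logits]
    return {name: (p, p >= threshold) for name, p in zip(names, probs)}


# Hypothetical logits for a 3-cell model; a real checkpoint returns num_labels values per window.
calls = methylation_calls([2.0, -1.5, 0.0])
for name, (p, methylated) in calls.items():
    print(f"{name}: p={p:.3f} methylated={methylated}")
```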

Interface

  • Input length: fixed 1001 bp DNA window centered on a CpG site
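  • Padding: batches of shorter, padded sequences are not supported; pad or crop each genomic window beforehand so it is exactly sequence_length (1001) bp long
  • Alphabet: DNA (A, C, G, T); the converter zero-initialises the first-convolution weights of the additional N channel from the MultiMolecule streamline alphabet, so N positions add nothing to the convolution output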
  • Output: per-cell methylation logits; the number of cells is dataset-specific (see Variants table)
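Inputs must be exactly sequence_length bp, so windows near chromosome ends need padding. This minimal sketch assumes:

- `pos` is the 0-based coordinate of the CpG's C.
- The window length is odd.
- `N` is the padding character. Given the zero-initialised N channel, `N` should add nothing to the first convolution.

The function name `centered_window` is hypothetical and not part of MultiMolecule.

```python
def centered_window(chrom_seq: str, pos: int, length: int = 1001, pad: str = "N") -> str:
    """Return `length` bp of `chrom_seq` centered on 0-based `pos`, padding with `pad` past either end."""
    if length % 2 != 1:
        raise ValueError("length must be odd so the site sits exactly in the middle")
    if not 0 <= pos < len(chrom_seq):
        raise ValueError(f"pos {pos} is outside the sequence")
    half = length // 2
    start, end = pos - half, pos + half + 1
    left = pad * max(0, -start)                   # padding needed before the sequence start
    right = pad * max(0, end - len(chrom_seq))    # padding needed after the sequence end
    window = left + chrom_seq[max(0, start):min(len(chrom_seq), end)] + right  # crop + pad
    return window.upper()


# Toy "chromosome" with a CpG close to the left end (0-based C at position 2).
chrom = "ttCGgatc" + "A" * 2000
w = centered_window(chrom, pos=2)
print(len(w), w[500:502])  # 1001 CG
```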

Training Details

DeepCpG-DNA was trained to predict the per-cell methylation state of CpG sites from their flanking DNA context.

Training Data

DeepCpG-DNA was trained on single-cell bisulfite sequencing datasets:

  • Smallwood 2014: scBS-seq profiles of mouse embryonic stem cells, with 18 serum and 12 2i mESCs (excluding two serum cells whose methylation pattern deviated strongly from the remainder).
  • Hou 2016: scRRBS-seq profiles of 25 human hepatocellular carcinoma (HCC) cells, 6 human hepatoblastoma-derived (HepG2) cells, and 6 mESCs, restricted to CpG sites covered by at least four reads.

Each training example is a 1001 bp DNA window centered on a CpG site. Each window carries one binary methylation label per cell (methylated or unmethylated), and the label is missing for cells in which the site was not observed. Whole chromosomes were assigned to the training, validation, or test set to avoid sequence leakage.
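Holding out whole chromosomes ensures that no training window shares flanking sequence with a held-out window. The sketch below illustrates such a split. The chromosome assignments and site coordinates are hypothetical placeholders, not the ones used by DeepCpG.

```python
from collections import defaultdict


def split_by_chromosome(sites, val_chroms, test_chroms):
    """Assign (chrom, pos) CpG sites to train/val/test by chromosome, so no chromosome spans two sets."""
    val_chroms, test_chroms = set(val_chroms), set(test_chroms)
    if val_chroms & test_chroms:
        raise ValueError("a chromosome cannot be in both the validation and test sets")
    splits = defaultdict(list)
    for chrom, pos in sites:
        if chrom in test_chroms:
            splits["test"].append((chrom, pos))
        elif chrom in val_chroms:
            splits["val"].append((chrom, pos))
        else:
            splits["train"].append((chrom, pos))
    return dict(splits)


# Hypothetical sites and hold-out chromosomes.
sites = [("chr1", 3_000_100), ("chr2", 150), ("chr3", 7_200), ("chr1", 42)]
splits = split_by_chromosome(sites, val_chroms=["chr2"], test_chroms=["chr3"])
print({name: len(members) for name, members in splits.items()})
```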

Training Procedure

Pre-training

The model was trained to minimize a per-cell binary cross-entropy loss, comparing its predicted per-cell methylation probabilities (sigmoid of the per-cell logits) against the observed single-cell bisulfite labels. Missing labels are masked out during training.

  • Optimizer: Adam
  • Loss: Per-cell binary cross-entropy
  • Regularization: Dropout and L2 weight decay
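The masked per-cell loss can be sketched in a few lines. The sketch makes these choices:

- Missing labels are encoded as -1, which is an assumption for this example.
- The binary cross-entropy is averaged over observed (window, cell) entries only.
- An optional L2 penalty is added, with a hypothetical coefficient `l2` rather than a reported hyperparameter.

```python
import math

MISSING = -1  # assumed sentinel for unobserved (window, cell) labels


def bce_with_logits(logit: float, label: float) -> float:
    """Numerically stable binary cross-entropy computed from a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def masked_per_cell_loss(logits, labels, weights=(), l2=0.0):
    """Mean BCE over observed entries of [batch][cells] logits/labels, plus l2 * sum(w ** 2)."""
    total, observed = 0.0, 0
    for logit_row, label_row in zip(logits, labels):
        for logit, label in zip(logit_row, label_row):
            if label == MISSING:
                continue  # missing labels contribute neither loss nor gradient
            total += bce_with_logits(logit, label)
            observed += 1
    data_loss = total / observed if observed else 0.0
    return data_loss + l2 * sum(w * w for w in weights)


# Two windows x three cells; MISSING marks cells without coverage at that CpG site.
logits = [[0.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
labels = [[1, MISSING, 0], [MISSING, 0, 1]]
print(round(masked_per_cell_loss(logits, labels), 4))
```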

Citation

BibTeX
@article{angermueller2017deepcpg,
  author    = {Angermueller, Christof and Lee, Heather J. and Reik, Wolf and Stegle, Oliver},
  title     = {{DeepCpG}: accurate prediction of single-cell {DNA} methylation states using deep learning},
  journal   = {Genome Biology},
  volume    = 18,
  number    = 1,
  pages     = {67},
  year      = 2017,
  publisher = {BioMed Central},
  doi       = {10.1186/s13059-017-1189-z}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the MultiMolecule GitHub issues for any questions or comments on the model card.

Please contact the authors of the DeepCpG paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.deepcpgdna

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

Name Type Description Default

alphabet

Alphabet | str | List[str] | None

alphabet to use for tokenization.

  • If None, the standard DNA alphabet will be used.
  • If a string, it should correspond to the name of a predefined alphabet. The options include
    • standard
    • iupac
    • streamline
    • nucleobase
  • If an alphabet or a list of characters, that specific alphabet will be used.
None

nmers

int

Size of kmer to tokenize.

1

codon

bool

Whether to tokenize into codons.

False

replace_U_with_T

bool

Whether to replace U with T.

True

do_upper_case

bool

Whether to convert input to uppercase.

True

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)

DeepCpgDnaConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a DeepCpgDnaModel. It is used to instantiate a DeepCpG-DNA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepCpG DNA submodule (cangermueller/deepcpg) CnnL2h128 architecture as distributed for the Smallwood2014 serum mESC checkpoint.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

Name Type Description Default

vocab_size

int

Vocabulary size of the DeepCpG-DNA model. DeepCpG consumes a one-hot encoding of DNA nucleotides, so this also defines the number of input channels of the first convolution. Defaults to 5 to match the MultiMolecule streamline DNA alphabet (A, C, G, T, N); the converter reorders the upstream four-channel kernel into this channel layout and leaves the N channel zero.

5

sequence_length

int

The fixed length of the DNA window (in base pairs) centered on a CpG site. Defaults to 1001.

1001

conv_channels

list[int] | None

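Number of filters for each convolutional layer. If None, defaults to [128, 256] (upstream CnnL2h128). Must have the same length as conv_kernel_sizes and conv_pool_sizes.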

None

conv_kernel_sizes

list[int] | None

Kernel size for each convolutional layer. If None, defaults to [11, 3].

None

conv_pool_sizes

list[int] | None

Max-pool size applied after each convolutional layer. If None, defaults to [4, 2].

None

bottleneck_size

int

Dimensionality of the dense bottleneck embedding. This is the model’s hidden size. Defaults to 128.

128

hidden_act

str

The non-linear activation function (function or string) in the encoder. If string, "gelu", "relu", "silu" and "gelu_new" are supported.

'relu'

hidden_dropout

float

The dropout probability for the bottleneck.

0.0

num_labels

int

Number of output labels. DeepCpG-DNA predicts per-cell methylation state, so this equals the number of single cells in the training dataset and is dataset-specific. Defaults to 18 to match the Smallwood2014 serum mESC checkpoint.

18

head

HeadConfig | None

The configuration of the prediction head. Defaults to a per-cell binary methylation head (problem_type="binary"), matching DeepCpG-DNA’s per-cell methylation task.

None

Examples:

Python Console Session
>>> from multimolecule import DeepCpgDnaConfig, DeepCpgDnaModel
>>> # Initializing a DeepCpG-DNA multimolecule/deepcpgdna style configuration
>>> configuration = DeepCpgDnaConfig()
>>> # Initializing a model (with random weights) from the multimolecule/deepcpgdna style configuration
>>> model = DeepCpgDnaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/deepcpgdna/configuration_deepcpgdna.py
Python
class DeepCpgDnaConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`DeepCpgDnaModel`][multimolecule.models.DeepCpgDnaModel]. It is used to instantiate a DeepCpG-DNA model according
    to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the DeepCpG DNA submodule
    ([cangermueller/deepcpg](https://github.com/cangermueller/deepcpg)) `CnnL2h128` architecture as distributed for the
    Smallwood2014 serum mESC checkpoint.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the DeepCpG-DNA model. DeepCpG consumes a one-hot encoding of DNA nucleotides, so this
            also defines the number of input channels of the first convolution. Defaults to 5 to match the
            MultiMolecule `streamline` DNA alphabet (`A`, `C`, `G`, `T`, `N`); the upstream four-channel kernel is
            reordered into this slot layout in the converter, leaving the `N` channel zero.
            Defaults to 5.
        sequence_length:
            The fixed length of the DNA window (in base pairs) centered on a CpG site.
            Defaults to 1001.
        conv_channels:
            Number of filters for each convolutional layer.
        conv_kernel_sizes:
            Kernel size for each convolutional layer.
        conv_pool_sizes:
            Max-pool size applied after each convolutional layer.
        bottleneck_size:
            Dimensionality of the dense bottleneck embedding. This is the model's hidden size.
            Defaults to 128.
        hidden_act:
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"silu"` and `"gelu_new"` are supported.
        hidden_dropout:
            The dropout probability for the bottleneck.
        num_labels:
            Number of output labels. DeepCpG-DNA predicts per-cell methylation state, so this equals the number of
            single cells in the training dataset and is **dataset-specific**.
            Defaults to 18 to match the Smallwood2014 serum mESC checkpoint.
        head:
            The configuration of the prediction head. Defaults to a per-cell binary methylation head
            (`problem_type="binary"`), matching DeepCpG-DNA's per-cell methylation task.

    Examples:
        >>> from multimolecule import DeepCpgDnaConfig, DeepCpgDnaModel
        >>> # Initializing a DeepCpG-DNA multimolecule/deepcpgdna style configuration
        >>> configuration = DeepCpgDnaConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/deepcpgdna style configuration
        >>> model = DeepCpgDnaModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "deepcpgdna"

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 1001,
        conv_channels: list[int] | None = None,
        conv_kernel_sizes: list[int] | None = None,
        conv_pool_sizes: list[int] | None = None,
        bottleneck_size: int = 128,
        hidden_act: str = "relu",
        hidden_dropout: float = 0.0,
        num_labels: int = 18,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        # Upstream `CnnL2h128`: Conv1D(128, 11) -> MaxPool(4) -> Conv1D(256, 3) -> MaxPool(2) -> Flatten -> Dense(128).
        if conv_channels is None:
            conv_channels = [128, 256]
        if conv_kernel_sizes is None:
            conv_kernel_sizes = [11, 3]
        if conv_pool_sizes is None:
            conv_pool_sizes = [4, 2]
        if not (len(conv_channels) == len(conv_kernel_sizes) == len(conv_pool_sizes)):
            raise ValueError(
                "conv_channels, conv_kernel_sizes and conv_pool_sizes must have the same length, but got "
                f"{len(conv_channels)}, {len(conv_kernel_sizes)} and {len(conv_pool_sizes)}."
            )
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, but got {sequence_length}.")
        if bottleneck_size <= 0:
            raise ValueError(f"bottleneck_size must be positive, but got {bottleneck_size}.")
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.conv_channels = conv_channels
        self.conv_kernel_sizes = conv_kernel_sizes
        self.conv_pool_sizes = conv_pool_sizes
        self.bottleneck_size = bottleneck_size
        self.hidden_size = bottleneck_size
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        # DeepCpG-DNA performs per-cell binary methylation prediction. The MultiMolecule `problem_type` convention
        # lives on the head config, since the Transformers base config only accepts the HF `problem_type` literals.
        if head is None:
            head = HeadConfig(problem_type="binary")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "binary"
        self.head = head
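The defaults above reproduce CnnL2h128. A CnnL3h128-style configuration, as used for the Smallwood 2014 2i checkpoint with 12 cells, can be described with the same keyword arguments.

The sketch below assumes the upstream third block is Conv1D(512, 3) followed by MaxPool(2). It also adds a hypothetical helper, `final_length`, which checks that a layer stack still leaves at least one position of the window. The constructor above does not perform this check.

```python
# Keyword arguments for a CnnL3h128-style DeepCpG-DNA configuration (assumed third block: Conv1D(512, 3) + MaxPool(2)).
CNN_L3_H128 = dict(
    conv_channels=[128, 256, 512],
    conv_kernel_sizes=[11, 3, 3],
    conv_pool_sizes=[4, 2, 2],
    bottleneck_size=128,
    num_labels=12,
)


def final_length(sequence_length, conv_kernel_sizes, conv_pool_sizes, **_):
    """Hypothetical helper: positions left after valid Conv1D + MaxPool blocks; raises if none remain."""
    length = sequence_length
    for kernel, pool in zip(conv_kernel_sizes, conv_pool_sizes):
        length = (length - kernel + 1) // pool
        if length <= 0:
            raise ValueError(f"a {sequence_length} bp window is too short for this layer stack")
    return length


print(final_length(1001, **CNN_L3_H128))  # 60
```

With MultiMolecule installed, these keyword arguments would be passed as `DeepCpgDnaConfig(**CNN_L3_H128)`. The constructor itself only checks that the three lists have equal length and that sequence_length and bottleneck_size are positive.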

DeepCpgDnaForSequencePrediction

Bases: DeepCpgDnaPreTrainedModel

The per-cell methylation (final dense) layer of DeepCpG-DNA is dataset-specific: it has one output per single cell in the training dataset. num_labels therefore equals the number of cells in the chosen dataset (18 for the shipped Smallwood2014 serum mESC checkpoint) and is exposed through the shared SequencePredictionHead decoder.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import DeepCpgDnaConfig, DeepCpgDnaForSequencePrediction, DnaTokenizer
>>> config = DeepCpgDnaConfig()
>>> model = DeepCpgDnaForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepcpgdna")
>>> input = tokenizer(["ACGT" * 250 + "A", "TGCA" * 250 + "T"], return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (2, 18)))
>>> output["logits"].shape
torch.Size([2, 18])
>>> output["loss"]
tensor(..., grad_fn=<...>)
Source code in multimolecule/models/deepcpgdna/modeling_deepcpgdna.py
Python
class DeepCpgDnaForSequencePrediction(DeepCpgDnaPreTrainedModel):
    """
    The per-cell methylation (final dense) layer of DeepCpG-DNA is **dataset-specific**: it has one output per single
    cell in the training dataset. `num_labels` therefore equals the number of cells in the chosen dataset (18 for the
    shipped Smallwood2014 serum mESC checkpoint) and is exposed through the shared
    [`SequencePredictionHead`][multimolecule.SequencePredictionHead] decoder.

    Examples:
        >>> import torch
        >>> from multimolecule import DeepCpgDnaConfig, DeepCpgDnaForSequencePrediction, DnaTokenizer
        >>> config = DeepCpgDnaConfig()
        >>> model = DeepCpgDnaForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepcpgdna")
        >>> input = tokenizer(["ACGT" * 250 + "A", "TGCA" * 250 + "T"], return_tensors="pt")
        >>> output = model(**input, labels=torch.randint(2, (2, 18)))
        >>> output["logits"].shape
        torch.Size([2, 18])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<...>)
    """

    def __init__(self, config: DeepCpgDnaConfig):
        super().__init__(config)
        self.model = DeepCpgDnaModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"cell_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"cell_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def postprocess(self, outputs: Any) -> Tensor:
        return torch.sigmoid(outputs["logits"])

DeepCpgDnaModel

Bases: DeepCpgDnaPreTrainedModel

Examples:

Python Console Session
>>> from multimolecule import DeepCpgDnaConfig, DeepCpgDnaModel, DnaTokenizer
>>> config = DeepCpgDnaConfig()
>>> model = DeepCpgDnaModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepcpgdna")
>>> input = tokenizer(["ACGT" * 250 + "A", "TGCA" * 250 + "T"], return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([2, 128])
Source code in multimolecule/models/deepcpgdna/modeling_deepcpgdna.py
Python
class DeepCpgDnaModel(DeepCpgDnaPreTrainedModel):
    """
    Examples:
        >>> from multimolecule import DeepCpgDnaConfig, DeepCpgDnaModel, DnaTokenizer
        >>> config = DeepCpgDnaConfig()
        >>> model = DeepCpgDnaModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/deepcpgdna")
        >>> input = tokenizer(["ACGT" * 250 + "A", "TGCA" * 250 + "T"], return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([2, 128])
    """

    def __init__(self, config: DeepCpgDnaConfig):
        super().__init__(config)
        self.embeddings = DeepCpgDnaEmbedding(config)
        self.encoder = DeepCpgDnaEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPoolingAndCrossAttentions:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        # The DeepCpG-DNA encoder collapses the sequence dimension through its dense bottleneck, so the final
        # bottleneck embedding is both the model's last hidden state and its pooled representation.
        sequence_output = self.encoder(embedding_output)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=sequence_output,
        )

DeepCpgDnaPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/deepcpgdna/modeling_deepcpgdna.py
Python
class DeepCpgDnaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DeepCpgDnaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["DeepCpgDnaConvLayer"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        super()._init_weights(module)
        # Upstream uses Keras `glorot_uniform` (the default) for both convolutions and dense layers.
        # Use transformers.initialization wrappers (imported as `init`); they check the `_is_hf_initialized`
        # flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            init.xavier_uniform_(module.weight)
            if module.bias is not None:
                init.zeros_(module.bias)
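Glorot (Xavier) uniform initialisation samples weights from U(-a, a), with a = sqrt(6 / (fan_in + fan_out)). The fans depend on the layer type:

- For a Conv1d weight of shape (out_channels, in_channels, kernel_size), fan_in = in_channels × kernel_size and fan_out = out_channels × kernel_size.
- For a Linear weight, the fans are simply its input and output sizes.

The sketch below evaluates these bounds for the default CnnL2h128 layers. It assumes 5 input channels and that the per-cell head is a single linear layer. It only computes the formula and does not touch any model.

```python
import math


def glorot_bound(fan_in: int, fan_out: int) -> float:
    """Half-width a of the Glorot/Xavier uniform distribution U(-a, a) with gain 1."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def conv1d_fans(in_channels: int, out_channels: int, kernel_size: int) -> tuple[int, int]:
    """Fan-in/fan-out of a Conv1d weight: the receptive field (kernel_size) scales both."""
    return in_channels * kernel_size, out_channels * kernel_size


layers = {
    "conv1": conv1d_fans(5, 128, 11),
    "conv2": conv1d_fans(128, 256, 3),
    "bottleneck": (122 * 256, 128),  # Flatten(122 positions x 256 filters) -> Dense(128)
    "cell_head": (128, 18),          # assumed single linear layer with 18 cells
}
for name, (fan_in, fan_out) in layers.items():
    print(f"{name}: fan_in={fan_in} fan_out={fan_out} bound={glorot_bound(fan_in, fan_out):.4f}")
```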