
a2z-chromatin

Disclaimer

This is an UNOFFICIAL implementation of Modeling chromatin state from sequence across angiosperms using recurrent convolutional neural networks by Travis Wrightsman et al.

The OFFICIAL repository of a2z-chromatin is at twrightsman/a2z-regulatory.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.

The team releasing a2z-chromatin did not write this model card, so it has been written by the MultiMolecule team.

Model Details

a2z-chromatin is a recurrent convolutional neural network (CNN+BLSTM, DanQ topology) trained to predict chromatin state from a fixed-length 600 bp one-hot encoded angiosperm DNA sequence. The single convolutional layer applies 320 filters with a kernel size of 26 and a ReLU activation, followed by dropout and non-overlapping max-pooling over 13 positions. The resulting feature sequence is fed to a bidirectional LSTM (320 units per direction) whose final forward and backward hidden states are concatenated, projected through a 925-unit dense layer, and read out as a single per-window logit, which a sigmoid converts to a probability.
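To make the tensor shapes concrete, here is a pure-Python sketch that traces one 600 bp window through the layers just described, using the default hyperparameters. It assumes a valid (unpadded) convolution and floor-mode max-pooling whose stride equals its window, and it leaves out the batch dimension; `trace_shapes` is a hypothetical helper, not part of MultiMolecule.

```python
def trace_shapes(
    sequence_length: int = 600,
    conv_channels: int = 320,
    conv_kernel_size: int = 26,
    pool_size: int = 13,
    lstm_hidden_size: int = 320,
    fc_size: int = 925,
    num_labels: int = 1,
) -> dict[str, tuple[int, ...]]:
    """Return per-layer output shapes (without the batch dimension) for one window.

    Hypothetical helper: assumes a valid (unpadded) convolution and floor-mode
    max-pooling whose stride equals its window size.
    """
    conv_length = sequence_length - conv_kernel_size + 1  # valid convolution
    pooled_length = conv_length // pool_size  # floor-mode pooling, stride == pool_size
    if pooled_length <= 0:
        raise ValueError("configuration leaves no positions after pooling")
    return {
        "conv": (conv_length, conv_channels),
        "pool": (pooled_length, conv_channels),
        # final forward and backward hidden states are concatenated
        "blstm": (2 * lstm_hidden_size,),
        "dense": (fc_size,),
        "logits": (num_labels,),
    }


for layer, shape in trace_shapes().items():
    print(f"{layer:>6}: {shape}")
```

With the defaults, the 575 convolution positions pool down to 44 steps (575 // 13), which the BLSTM summarizes into a 640-dimensional vector before the 925-unit dense layer.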

Two checkpoints are released by the authors: a2z-accessibility (predicts chromatin accessibility from leaf ATAC-seq) and a2z-methylation (predicts lack of CG/CHG/CHH DNA methylation). Both share the same architecture and differ only in the supervision used during training. The canonical MultiMolecule checkpoint is the accessibility model and is registered under regulatory-sequence-prediction; the methylation checkpoint can be converted with the same architecture but belongs to a DNA methylation task rather than the regulatory-sequence task.

Please refer to the Training Details section for more information on the training process.

Model Specification

| Num Conv Layers | Num LSTM Layers | Hidden Size | Num Parameters (M) | FLOPs (M) | MACs (M) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 1 (bidirectional) | 925 | 1.23 | 14.61 | 7.30 | 600 |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Chromatin State Prediction

You can use this model directly to predict the chromatin accessibility (or lack of DNA methylation, for the methylation variant) of a 600 bp angiosperm DNA sequence:

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, A2zChromatinForSequencePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/a2zchromatin")
>>> model = A2zChromatinForSequencePrediction.from_pretrained("multimolecule/a2zchromatin")
>>> input = tokenizer("ACGT" * 150, return_tensors="pt")
>>> output = model(**input)

>>> output.logits.shape
torch.Size([1, 1])
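Because the model accepts exactly one 600 bp window, a longer region has to be split before tokenization. The following is a minimal sketch, not part of MultiMolecule: `tile_sequence` is a hypothetical helper that emits windows with a configurable stride and, as an assumption, right-pads a trailing partial window with `N` rather than dropping it. Upstream may treat region edges differently.

```python
def tile_sequence(
    sequence: str,
    window: int = 600,
    stride: int = 600,
    pad_char: str = "N",
) -> list[tuple[int, str]]:
    """Split `sequence` into fixed-length windows as (start, window_sequence) pairs.

    Hypothetical helper. A trailing window shorter than `window` is right-padded
    with `pad_char` (an assumption, not upstream behaviour).
    """
    if window <= 0 or stride <= 0:
        raise ValueError("window and stride must be positive")
    tiles = []
    start = 0
    while start < len(sequence):
        chunk = sequence[start : start + window]
        tiles.append((start, chunk.ljust(window, pad_char)))
        if start + window >= len(sequence):  # this window already reaches the end
            break
        start += stride
    return tiles


tiles = tile_sequence("ACGT" * 400, stride=300)  # 1600 bp region, 50% overlap
print([start for start, _ in tiles])
```

The window strings can then be tokenized as one batch, for example `tokenizer([s for _, s in tiles], return_tensors="pt")`, and the start offsets used to map each per-window score back to genomic coordinates.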

Interface

  • Input length: fixed 600 bp DNA window
  • Alphabet: DNA IUPAC tokens; ambiguous bases use the upstream fractional A/C/G/T mixtures, and tokens outside the IUPAC alphabet are encoded as all-zero vectors
  • Output: single per-window logit (binary chromatin accessibility for a2z-accessibility, lack of DNA methylation for a2z-methylation)
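The fractional encoding of ambiguous bases can be illustrated with a small pure-Python sketch. It assumes the common convention of splitting probability mass uniformly across the bases an IUPAC code stands for (for example `R`, A or G, becomes `[0.5, 0, 0.5, 0]`) and encodes anything outside the IUPAC alphabet as an all-zero row, as the Alphabet item above states. The exact upstream table is not reproduced here, so treat `IUPAC_BASES` and `one_hot` as hypothetical.

```python
# Hypothetical table: each IUPAC nucleotide code -> the bases it stands for.
IUPAC_BASES = {
    "A": "A", "C": "C", "G": "G", "T": "T",
    "R": "AG", "Y": "CT", "S": "CG", "W": "AT", "K": "GT", "M": "AC",
    "B": "CGT", "D": "AGT", "H": "ACT", "V": "ACG", "N": "ACGT",
}
CHANNELS = "ACGT"


def one_hot(sequence: str) -> list[list[float]]:
    """Encode DNA as a (length, 4) list of A/C/G/T fractions.

    Ambiguity codes split their mass uniformly over the bases they denote;
    characters outside IUPAC_BASES become all-zero rows.
    """
    rows = []
    for char in sequence.upper():
        bases = IUPAC_BASES.get(char, "")
        # The division only runs when `base in bases`, so `bases` is never empty there.
        rows.append([1.0 / len(bases) if base in bases else 0.0 for base in CHANNELS])
    return rows


for char, row in zip("ACRN-", one_hot("ACRN-")):
    print(char, row)
```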

Training Details

a2z-chromatin was trained to predict per-window chromatin state across angiosperms using a single shared cross-species DanQ topology.

Training Data

a2z-chromatin was trained on two cross-species data resources:

  • Chromatin accessibility: leaf ATAC-seq peaks from 12 angiosperm species, with each 600 bp genomic interval labelled as accessible or inaccessible.
  • DNA methylation: unmethylated-region (UMR) calls from 10 angiosperm species, with each 600 bp genomic interval labelled as unmethylated or methylated. Unmethylated regions overlap significantly with accessible chromatin in plants, so the two tasks share the same architecture.

Each training example is a 600 bp one-hot encoded DNA sequence with a single binary label.

Training Procedure

Pre-training

Each variant was trained to minimize the binary cross-entropy between its sigmoid-activated per-window prediction and the observed accessibility or unmethylation label; generalization was evaluated across a series of cross-species splits.

  • Optimizer: Adam
  • Loss: Binary cross-entropy
  • Regularization: Dropout (0.2 after the convolution, 0.5 after the bidirectional LSTM)
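For reference, the loss above applied to a single logit can be written in the numerically stable "with logits" form, which avoids taking the log of a sigmoid that has saturated to 0 or 1. This is a pure-Python sketch of the standard formula, not the authors' training code or MultiMolecule's head implementation; `sigmoid` and `bce_with_logits` are illustrative names.

```python
import math


def sigmoid(logit: float) -> float:
    """Map a logit to a probability without overflowing for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_logit = math.exp(logit)
    return exp_logit / (1.0 + exp_logit)


def bce_with_logits(logit: float, label: float) -> float:
    """Binary cross-entropy between sigmoid(logit) and a 0/1 label.

    Uses max(z, 0) - z * y + log(1 + exp(-|z|)), which equals
    -[y * log(p) + (1 - y) * log(1 - p)] with p = sigmoid(z).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


print(sigmoid(0.0))               # 0.5
print(bce_with_logits(0.0, 1.0))  # log(2), about 0.6931
```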

Citation

BibTeX
@article{wrightsman2022a2z,
  author    = {Wrightsman, Travis and Marand, Alexandre P. and Crisp, Peter A. and Springer, Nathan M. and Buckler, Edward S.},
  title     = {Modeling chromatin state from sequence across angiosperms using recurrent convolutional neural networks},
  journal   = {The Plant Genome},
  volume    = 15,
  number    = 3,
  pages     = {e20249},
  year      = 2022,
  publisher = {Wiley},
  doi       = {10.1002/tpg2.20249}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the a2z-chromatin paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.a2zchromatin

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| alphabet | Alphabet \| str \| List[str] \| None | Alphabet to use for tokenization. If it is None, the standard DNA alphabet is used. If it is a string, it should name a predefined alphabet: standard, iupac, streamline, or nucleobase. If it is an Alphabet or a list of characters, that specific alphabet is used. | None |
| nmers | int | Size of k-mer to tokenize. | 1 |
| codon | bool | Whether to tokenize into codons. | False |
| replace_U_with_T | bool | Whether to replace U with T. | True |
| do_upper_case | bool | Whether to convert input to uppercase. | True |

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If it is `None`, the standard DNA alphabet will be used.
            - If it is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If it is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
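The difference between `nmers` and `codon` in `_tokenize` is overlapping versus non-overlapping splitting. The standalone sketch below repeats that logic outside the tokenizer class (the name `split_tokens` is hypothetical and it hard-codes the default upper-casing and U-to-T replacement) and reproduces the token strings behind the `tataaagta` examples above: seven overlapping 3-mers versus three codons.

```python
def split_tokens(text: str, nmers: int = 1, codon: bool = False) -> list[str]:
    """Standalone version of the splitting step in DnaTokenizer._tokenize."""
    text = text.upper().replace("U", "T")  # defaults: do_upper_case=True, replace_U_with_T=True
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length must be a multiple of 3 for codons, got {len(text)}")
        return [text[i : i + 3] for i in range(0, len(text), 3)]  # non-overlapping triplets
    if nmers > 1:
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]  # sliding window, step 1
    return list(text)


print(split_tokens("tataaagta", nmers=3))
# ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(split_tokens("tataaagta", codon=True))
# ['TAT', 'AAA', 'GTA']
```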

A2zChromatinConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of an A2zChromatinModel. It is used to instantiate an a2z-chromatin model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the a2z-chromatin twrightsman/a2z-regulatory architecture (DanQ topology trained on angiosperm chromatin data, distributed via Kipoi as a2z-accessibility and a2z-methylation).

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vocab_size | int | Vocabulary size of the a2z-chromatin model. Upstream a2z-chromatin consumes four nucleotide channels, but the converted MultiMolecule checkpoint expands the first convolution to the DNA IUPAC tokenizer alphabet so ambiguity tokens reproduce upstream fractional one-hot encodings. | 16 |
| sequence_length | int | The fixed length of the input DNA sequence in base pairs. | 600 |
| conv_channels | int | Number of filters in the first (and only) 1D convolution. | 320 |
| conv_kernel_size | int | Kernel size of the 1D convolution. | 26 |
| conv_dropout | float | Dropout probability applied after the convolution. | 0.2 |
| pool_size | int | Max-pool window size and stride applied after the convolution. | 13 |
| lstm_hidden_size | int | Hidden dimensionality of each direction of the bidirectional LSTM. | 320 |
| lstm_dropout | float | Dropout probability applied after the bidirectional LSTM. | 0.5 |
| fc_size | int | Hidden dimensionality of the fully-connected layer between the LSTM and the prediction head. | 925 |
| hidden_act | str | The non-linear activation function (function or string) applied after the convolution. If a string, "gelu", "relu", "silu" and "gelu_new" are supported. | 'relu' |
| num_labels | int | Number of output labels. a2z-chromatin predicts a single binary target (chromatin accessibility for the a2z-accessibility variant, lack of DNA methylation for the a2z-methylation variant). | 1 |
| head | HeadConfig \| None | The configuration of the prediction head. Defaults to a binary classification head (problem_type="binary"), matching a2z-chromatin's per-window accessibility / unmethylation task. | None |
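The `vocab_size` entry above can be checked with a small numerical sketch: if each token's first-layer weight is the fraction-weighted sum of the four A/C/G/T weights, a one-hot token produces the same pre-activation as the fractional A/C/G/T vector fed to the original four-channel filter. Because convolution is linear, checking one kernel offset of one filter is enough. The token subset and fractions below are illustrative assumptions; the checkpoint's actual token order and conversion code are not shown in this card.

```python
import random

# Illustrative token subset -> A/C/G/T fractions (uniform split assumed for ambiguity codes).
TOKEN_FRACTIONS = {
    "A": [1.0, 0.0, 0.0, 0.0],
    "C": [0.0, 1.0, 0.0, 0.0],
    "G": [0.0, 0.0, 1.0, 0.0],
    "T": [0.0, 0.0, 0.0, 1.0],
    "R": [0.5, 0.0, 0.5, 0.0],
    "N": [0.25, 0.25, 0.25, 0.25],
    "-": [0.0, 0.0, 0.0, 0.0],  # non-IUPAC token: all-zero encoding
}
TOKENS = list(TOKEN_FRACTIONS)


def expand_weights(weight_acgt: list[float]) -> list[float]:
    """Turn one filter's 4 per-base weights (at one kernel offset) into per-token weights."""
    return [sum(f * w for f, w in zip(TOKEN_FRACTIONS[token], weight_acgt)) for token in TOKENS]


rng = random.Random(0)
weight_acgt = [rng.uniform(-1, 1) for _ in range(4)]  # stand-in for trained weights
weight_tokens = expand_weights(weight_acgt)

for index, token in enumerate(TOKENS):
    one_hot_token = [1.0 if i == index else 0.0 for i in range(len(TOKENS))]
    via_tokens = sum(x * w for x, w in zip(one_hot_token, weight_tokens))
    via_fractions = sum(x * w for x, w in zip(TOKEN_FRACTIONS[token], weight_acgt))
    assert abs(via_tokens - via_fractions) < 1e-12, token
print("one-hot tokens reproduce fractional A/C/G/T inputs")
```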

Examples:

Python Console Session
>>> from multimolecule import A2zChromatinConfig, A2zChromatinModel
>>> # Initializing an a2z-chromatin multimolecule/a2zchromatin style configuration
>>> configuration = A2zChromatinConfig()
>>> # Initializing a model (with random weights) from the multimolecule/a2zchromatin style configuration
>>> model = A2zChromatinModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/a2zchromatin/configuration_a2zchromatin.py
Python
class A2zChromatinConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of an
    [`A2zChromatinModel`][multimolecule.models.A2zChromatinModel]. It is used to instantiate an a2z-chromatin model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the a2z-chromatin
    [twrightsman/a2z-regulatory](https://github.com/twrightsman/a2z-regulatory) architecture (DanQ topology trained on
    angiosperm chromatin data, distributed via Kipoi as `a2z-accessibility` and `a2z-methylation`).

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the a2z-chromatin model. Upstream a2z-chromatin consumes four nucleotide channels, but
            the converted MultiMolecule checkpoint expands the first convolution to the DNA IUPAC tokenizer alphabet so
            ambiguity tokens reproduce upstream fractional one-hot encodings.
            Defaults to 16.
        sequence_length:
            The fixed length of the input DNA sequence in base pairs.
            Defaults to 600.
        conv_channels:
            Number of filters in the first (and only) 1D convolution.
            Defaults to 320.
        conv_kernel_size:
            Kernel size of the 1D convolution.
            Defaults to 26.
        conv_dropout:
            Dropout probability applied after the convolution.
            Defaults to 0.2.
        pool_size:
            Max-pool window size and stride applied after the convolution.
            Defaults to 13.
        lstm_hidden_size:
            Hidden dimensionality of each direction of the bidirectional LSTM.
            Defaults to 320.
        lstm_dropout:
            Dropout probability applied after the bidirectional LSTM.
            Defaults to 0.5.
        fc_size:
            Hidden dimensionality of the fully-connected layer between the LSTM and the prediction head.
            Defaults to 925.
        hidden_act:
            The non-linear activation function (function or string) applied after the convolution. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
            Defaults to `"relu"`.
        num_labels:
            Number of output labels. a2z-chromatin predicts a single binary target (chromatin accessibility for the
            `a2z-accessibility` variant, lack of DNA methylation for the `a2z-methylation` variant).
            Defaults to 1.
        head:
            The configuration of the prediction head. Defaults to a binary classification head
            (`problem_type="binary"`), matching a2z-chromatin's per-window accessibility / unmethylation task.

    Examples:
        >>> from multimolecule import A2zChromatinConfig, A2zChromatinModel
        >>> # Initializing an a2z-chromatin multimolecule/a2zchromatin style configuration
        >>> configuration = A2zChromatinConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/a2zchromatin style configuration
        >>> model = A2zChromatinModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "a2zchromatin"

    def __init__(
        self,
        vocab_size: int = 16,
        sequence_length: int = 600,
        conv_channels: int = 320,
        conv_kernel_size: int = 26,
        conv_dropout: float = 0.2,
        pool_size: int = 13,
        lstm_hidden_size: int = 320,
        lstm_dropout: float = 0.5,
        fc_size: int = 925,
        hidden_act: str = "relu",
        num_labels: int = 1,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, but got {sequence_length}.")
        if conv_channels <= 0:
            raise ValueError(f"conv_channels must be positive, but got {conv_channels}.")
        if conv_kernel_size <= 0:
            raise ValueError(f"conv_kernel_size must be positive, but got {conv_kernel_size}.")
        if pool_size <= 0:
            raise ValueError(f"pool_size must be positive, but got {pool_size}.")
        if lstm_hidden_size <= 0:
            raise ValueError(f"lstm_hidden_size must be positive, but got {lstm_hidden_size}.")
        if fc_size <= 0:
            raise ValueError(f"fc_size must be positive, but got {fc_size}.")
        # Upstream DanQ topology uses valid (zero-padding) convolution followed by floor-mode pooling with
        # stride == pool_size; require the configured (sequence_length, conv_kernel_size, pool_size) triple to
        # leave at least one pooled position so the BLSTM has a non-empty input window.
        conv_out_length = sequence_length - conv_kernel_size + 1
        if conv_out_length <= 0:
            raise ValueError(
                f"sequence_length ({sequence_length}) must be at least conv_kernel_size ({conv_kernel_size})."
            )
        pooled_length = conv_out_length // pool_size
        if pooled_length <= 0:
            raise ValueError(
                f"The configured (sequence_length={sequence_length}, conv_kernel_size={conv_kernel_size}, "
                f"pool_size={pool_size}) leaves no positions after pooling."
            )
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.conv_channels = conv_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_dropout = conv_dropout
        self.pool_size = pool_size
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_dropout = lstm_dropout
        self.fc_size = fc_size
        self.hidden_size = fc_size
        self.hidden_act = hidden_act
        # a2z-chromatin performs per-window binary prediction (accessibility or unmethylation). The MultiMolecule
        # `problem_type` convention lives on the head config, since the Transformers base config only accepts the
        # HF `problem_type` literals.
        if head is None:
            head = HeadConfig(problem_type="binary")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "binary"
        self.head = head
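The two length checks in `__init__` imply a minimum window for a given kernel and pool size: the convolution must yield at least `pool_size` positions, so `sequence_length - conv_kernel_size + 1 >= pool_size`, that is `sequence_length >= conv_kernel_size + pool_size - 1`. The pure-Python sketch below mirrors that arithmetic; the helper names are hypothetical, and `pooled_length` returns 0 where the configuration would raise instead.

```python
def pooled_length(sequence_length: int, conv_kernel_size: int = 26, pool_size: int = 13) -> int:
    """Positions left after valid convolution and floor-mode pooling (0 = config would raise)."""
    conv_out_length = sequence_length - conv_kernel_size + 1
    if conv_out_length <= 0:
        return 0
    return conv_out_length // pool_size


def min_sequence_length(conv_kernel_size: int = 26, pool_size: int = 13) -> int:
    """Smallest sequence_length that leaves one pooled position: kernel + pool - 1."""
    return conv_kernel_size + pool_size - 1


print(min_sequence_length())  # 38
print(pooled_length(37), pooled_length(38), pooled_length(600))  # 0 1 44
```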

A2zChromatinForSequencePrediction

Bases: A2zChromatinPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import A2zChromatinConfig, A2zChromatinForSequencePrediction, DnaTokenizer
>>> config = A2zChromatinConfig()
>>> model = A2zChromatinForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/a2zchromatin")
>>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (2, 1)))
>>> output["logits"].shape
torch.Size([2, 1])
>>> output["loss"]
tensor(..., grad_fn=<...>)
Source code in multimolecule/models/a2zchromatin/modeling_a2zchromatin.py
Python
class A2zChromatinForSequencePrediction(A2zChromatinPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import A2zChromatinConfig, A2zChromatinForSequencePrediction, DnaTokenizer
        >>> config = A2zChromatinConfig()
        >>> model = A2zChromatinForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/a2zchromatin")
        >>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
        >>> output = model(**input, labels=torch.randint(2, (2, 1)))
        >>> output["logits"].shape
        torch.Size([2, 1])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<...>)
    """

    def __init__(self, config: A2zChromatinConfig):
        super().__init__(config)
        self.model = A2zChromatinModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"chromatin_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"chromatin_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def postprocess(self, outputs: Any) -> Tensor:
        return torch.sigmoid(outputs["logits"])
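The label-naming rule in `output_channels` can be checked without building a model. The function below is a standalone copy of that logic for illustration (the name `channel_names` is hypothetical): Transformers' auto-generated `LABEL_i` placeholders, or a missing `id2label`, fall back to `chromatin_i`, while custom labels are returned unchanged.

```python
from typing import Dict, List, Optional


def channel_names(num_labels: int, id2label: Optional[Dict[int, str]]) -> List[str]:
    """Mirror of the naming rule in A2zChromatinForSequencePrediction.output_channels."""
    if id2label is not None:
        labels = [str(id2label.get(index, f"chromatin_{index}")) for index in range(num_labels)]
        # Keep the labels only if at least one differs from the LABEL_i placeholder.
        if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
            return labels
    return [f"chromatin_{index}" for index in range(num_labels)]


print(channel_names(1, {0: "LABEL_0"}))     # ['chromatin_0']
print(channel_names(1, {0: "accessible"}))  # ['accessible']
```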

A2zChromatinModel

Bases: A2zChromatinPreTrainedModel

Examples:

Python Console Session
>>> from multimolecule import A2zChromatinConfig, A2zChromatinModel, DnaTokenizer
>>> config = A2zChromatinConfig()
>>> model = A2zChromatinModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/a2zchromatin")
>>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([2, 925])
Source code in multimolecule/models/a2zchromatin/modeling_a2zchromatin.py
Python
class A2zChromatinModel(A2zChromatinPreTrainedModel):
    """
    Examples:
        >>> from multimolecule import A2zChromatinConfig, A2zChromatinModel, DnaTokenizer
        >>> config = A2zChromatinConfig()
        >>> model = A2zChromatinModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/a2zchromatin")
        >>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([2, 925])
    """

    def __init__(self, config: A2zChromatinConfig):
        super().__init__(config)
        self.embeddings = A2zChromatinEmbedding(config)
        self.encoder = A2zChromatinEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> A2zChromatinModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        # The a2z-chromatin encoder collapses the sequence dimension through the bidirectional LSTM and a dense
        # projection, so the final feature vector is both the model's last hidden state and its pooled representation.
        sequence_output = self.encoder(embedding_output)

        return A2zChromatinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=sequence_output,
        )

A2zChromatinModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the a2z-chromatin backbone.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| last_hidden_state | `torch.FloatTensor` of shape `(batch_size, hidden_size)` | Sequence-level representation produced by the DanQ CNN+BLSTM encoder and dense projection. The upstream Keras model returns only this final feature vector rather than per-position hidden states. | None |
| pooler_output | `torch.FloatTensor` of shape `(batch_size, hidden_size)` | Alias of last_hidden_state; this is the tensor consumed by SequencePredictionHead. | None |
| hidden_states | `tuple(torch.FloatTensor)`, *optional* | Always None; a2z-chromatin does not record intermediate hidden states. | None |
| attentions | `tuple(torch.FloatTensor)`, *optional* | Always None; a2z-chromatin is a convolutional/recurrent model without attention. | None |

Source code in multimolecule/models/a2zchromatin/modeling_a2zchromatin.py
Python
@dataclass
class A2zChromatinModelOutput(ModelOutput):
    """
    Base class for outputs of the a2z-chromatin backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Sequence-level representation produced by the DanQ CNN+BLSTM encoder and dense projection. The upstream
            Keras model returns only this final feature vector rather than per-position hidden states.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Alias of `last_hidden_state`; this is the tensor consumed by
            [`SequencePredictionHead`][multimolecule.modules.SequencePredictionHead].
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Always `None`; a2z-chromatin does not record intermediate hidden states.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Always `None`; a2z-chromatin is a convolutional/recurrent model without attention.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None

A2zChromatinPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/a2zchromatin/modeling_a2zchromatin.py
Python
class A2zChromatinPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = A2zChromatinConfig
    base_model_prefix = "model"
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["A2zChromatinEncoder"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.LSTM):
            for name, parameter in module.named_parameters():
                if "weight" in name:
                    init.kaiming_normal_(parameter, mode="fan_in", nonlinearity="relu")
                elif "bias" in name:
                    init.zeros_(parameter)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
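With `mode="fan_in"` and `nonlinearity="relu"`, Kaiming-normal initialization draws weights with standard deviation sqrt(2) / sqrt(fan_in). The sketch below works out those values for the default configuration. It assumes a `vocab_size`-channel input to the convolution and a dense layer fed by the concatenated 2 × 320 BLSTM state, as described in the model card, and it only matters for freshly initialized models, since the comment in `_init_weights` notes that tensors loaded from a checkpoint are not overwritten.

```python
import math


def kaiming_normal_std(fan_in: int, gain: float = math.sqrt(2.0)) -> float:
    """Std used by kaiming_normal_(mode="fan_in", nonlinearity="relu"): gain / sqrt(fan_in)."""
    return gain / math.sqrt(fan_in)


# Default A2zChromatinConfig values.
vocab_size, conv_kernel_size, conv_channels = 16, 26, 320
lstm_hidden_size, fc_size = 320, 925
gates = 4 * lstm_hidden_size  # PyTorch stacks the four LSTM gates along dim 0

# fan_in of each weight tensor (layer shapes assumed from the model card's description).
fan_ins = {
    f"conv weight ({conv_channels}, {vocab_size}, {conv_kernel_size})": vocab_size * conv_kernel_size,
    f"lstm weight_ih ({gates}, {conv_channels})": conv_channels,
    f"lstm weight_hh ({gates}, {lstm_hidden_size})": lstm_hidden_size,
    f"dense weight ({fc_size}, {2 * lstm_hidden_size})": 2 * lstm_hidden_size,  # concatenated directions
}
for name, fan_in in fan_ins.items():
    print(f"{name}: fan_in={fan_in}, std={kaiming_normal_std(fan_in):.4f}")
```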