HyenaDNA

A model pre-trained on the human reference genome using a causal language modeling (CLM) objective with the Hyena operator.

Disclaimer

This is an UNOFFICIAL implementation of "HyenaDNA: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution" by Eric Nguyen, Michael Poli, Marjan Faizi, et al.

The OFFICIAL repository of HyenaDNA is at HazyResearch/hyena-dna.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released HyenaDNA did not write a model card for this model, so this model card was written by the MultiMolecule team.

Model Details

HyenaDNA is a decoder-only model pre-trained on the human reference genome with single nucleotide tokenization in a self-supervised fashion. This means that the model was trained on the raw nucleotides of DNA sequences only, with an automatic process to generate inputs and labels from those sequences. Please refer to the Training Details section for more information on the training process.

HyenaDNA replaces the attention mechanism with the Hyena operator — a subquadratic sequence mixer based on implicit long convolutions. This enables O(L log L) complexity for sequence modeling, allowing the model to handle context lengths up to 1 million base pairs at single nucleotide resolution.
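
To get a feel for the gap between quadratic and O(L log L) scaling at the context lengths mentioned above, the following back-of-the-envelope sketch compares the two operation counts. It ignores constant factors, hidden size, and hardware effects, and the helper names are illustrative only, not part of any library.

```python
import math


def quadratic_cost(length: int) -> int:
    """Pairwise interactions of dense attention: L * L (constants ignored)."""
    return length * length


def fft_conv_cost(length: int) -> float:
    """Rough cost of an FFT-based long convolution: L * log2(L) (constants ignored)."""
    return length * math.log2(length)


# Context lengths roughly matching the released checkpoints.
for length in (1_024, 32_768, 1_000_000):
    ratio = quadratic_cost(length) / fft_conv_cost(length)
    print(f"L={length:>9,}: L^2 / (L log2 L) = {ratio:,.0f}x")
```

The ratio grows as L / log2(L), which is why the advantage only becomes dramatic at very long contexts.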

The Hyena operator uses:

  • Implicit filters: MLP-parameterized convolution kernels with learned positional embeddings
  • Element-wise gating: Multiplicative interactions between projected input channels
  • FFT convolution: Efficient computation via the Fast Fourier Transform
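
The last two ingredients can be sketched in a few lines of NumPy. This is a simplified, single-channel illustration under stated assumptions, not the MultiMolecule or original implementation: the filter `k` is random here (in the real model it would come from the implicit filter MLP), and `gated_long_conv` is a hypothetical helper showing one gating step around one long convolution.

```python
import numpy as np


def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution of a length-L signal with a length-L filter via FFT."""
    length = u.shape[-1]
    n = 2 * length  # zero-pad so the circular convolution does not wrap around
    spectrum = np.fft.rfft(u, n=n) * np.fft.rfft(k, n=n)
    return np.fft.irfft(spectrum, n=n)[..., :length]


def direct_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Reference O(L^2) implementation: y[t] = sum_{j<=t} k[j] * u[t - j]."""
    length = u.shape[-1]
    return np.array([sum(k[j] * u[t - j] for j in range(t + 1)) for t in range(length)])


def gated_long_conv(v, gate_in, gate_out, k):
    """Element-wise gating before and after a long convolution (simplified)."""
    return gate_out * fft_causal_conv(gate_in * v, k)


rng = np.random.default_rng(0)
v, x1, x2, k = rng.standard_normal((4, 16))  # four toy signals of length 16

# The FFT path matches the direct sum up to floating-point error.
print(np.allclose(fft_causal_conv(v, k), direct_causal_conv(v, k)))  # True
y = gated_long_conv(v, x1, x2, k)
```

Zero-padding to twice the sequence length is what keeps the convolution causal: without it, the FFT computes a circular convolution in which late positions leak into early ones.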

Variants

HyenaDNA was pretrained across 5 architectures (Table A.1 in the paper), of which 4 have released checkpoints. We provide one converted checkpoint per architecture, using the longest available context length:

Note

The original repository releases 7 checkpoints with inconsistent naming (e.g., “large” and “medium” share the same architecture, and multiple “tiny” variants are identical except for context length). We retain only the best (longest context) checkpoint per architecture and use a consistent size-based naming scheme. The mapping from original to MultiMolecule names is:

| Original Checkpoint | MultiMolecule Name | Architecture |
|---|---|---|
| hyenadna-large-1m-seqlen-hf | hyenadna-large | 8 layers, 256 dim |
| hyenadna-small-32k-seqlen-hf | hyenadna-medium | 4 layers, 256 dim |
| hyenadna-tiny-1k-seqlen-d256-hf | hyenadna-small | 2 layers, 256 dim |
| hyenadna-tiny-16k-seqlen-d128-hf | hyenadna-tiny | 2 layers, 128 dim |

The (4 layers, 128 dim) architecture from the paper has no released checkpoint. The remaining 3 original checkpoints (hyenadna-medium-450k-seqlen-hf, hyenadna-medium-160k-seqlen-hf, hyenadna-tiny-1k-seqlen-hf) each share an architecture with one of the checkpoints above but have shorter context lengths, so they are not provided.

Model Specification

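| Variants | Num Layers | Hidden Size | Intermediate Size | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
|---|---|---|---|---|---|---|---|
| HyenaDNA-large | 8 | 256 | 1024 | 6.62 | 6.69 | 3.35 | 1,000,002 |
| HyenaDNA-medium | 4 | 256 | 1024 | 3.34 | 3.35 | 1.67 | 32,770 |
| HyenaDNA-small | 2 | 256 | 1024 | 1.71 | 1.67 | 0.84 | 1,026 |
| HyenaDNA-tiny | 2 | 128 | 512 | 0.45 | 0.44 | 0.22 | 16,386 |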

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Text Generation

You can use this model directly with a pipeline for text generation:

Python
import multimolecule  # you must import multimolecule to register models
from transformers import pipeline

generator = pipeline("text-generation", model="multimolecule/hyenadna-large")
output = generator("ATCGATCGATCG", max_new_tokens=50)

Downstream Use

Extract Features

Here is how to use this model to get the features of a given sequence in PyTorch:

Python
from multimolecule import DnaTokenizer, HyenaDnaModel


tokenizer = DnaTokenizer.from_pretrained("multimolecule/hyenadna-large")
model = HyenaDnaModel.from_pretrained("multimolecule/hyenadna-large")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")

output = model(**input)

Sequence Classification / Regression

Note

This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for sequence classification or regression.

Here is how to use this model as a backbone to fine-tune for a sequence-level task in PyTorch:

Python
import torch
from multimolecule import DnaTokenizer, HyenaDnaForSequencePrediction


tokenizer = DnaTokenizer.from_pretrained("multimolecule/hyenadna-large")
model = HyenaDnaForSequencePrediction.from_pretrained("multimolecule/hyenadna-large")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")
label = torch.tensor([1])

output = model(**input, labels=label)

Token Classification / Regression

Note

This model is not fine-tuned for any specific task. You will need to fine-tune the model on a downstream task to use it for token classification or regression.

Here is how to use this model as a backbone to fine-tune for a nucleotide-level task in PyTorch:

Python
import torch
from multimolecule import DnaTokenizer, HyenaDnaForTokenPrediction


tokenizer = DnaTokenizer.from_pretrained("multimolecule/hyenadna-large")
model = HyenaDnaForTokenPrediction.from_pretrained("multimolecule/hyenadna-large")

text = "ATCGATCGATCGATCG"
input = tokenizer(text, return_tensors="pt")
label = torch.randint(2, (len(text), ))

output = model(**input, labels=label)

Training Details

HyenaDNA used Causal Language Modeling (CLM) as the pre-training objective: given a DNA sequence, the model is trained to predict the next nucleotide token autoregressively.

Training Data

The HyenaDNA model was pre-trained on the human reference genome (GRCh38). The training data consists of single nucleotide-level DNA sequences from all human chromosomes. Sequences are tokenized at the individual character level (A, C, G, T, N) without k-mer encoding.

The dataset is split into training and test sets by chromosome, with held-out chromosomes used for evaluation.
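
A chromosome-level split can be sketched as follows. The records and the held-out chromosome names are hypothetical placeholders chosen for illustration; this card does not state which chromosomes were actually held out.

```python
def split_by_chromosome(records, held_out):
    """Partition (chromosome, sequence) records into train and test lists."""
    train, test = [], []
    for chrom, seq in records:
        (test if chrom in held_out else train).append((chrom, seq))
    return train, test


# Toy records; the held-out names below are placeholders, not the real split.
records = [("chr1", "ACGTN"), ("chr2", "GGCAT"), ("chr8", "TTAGC"), ("chr9", "CATGA")]
train, test = split_by_chromosome(records, held_out={"chr8", "chr9"})
print([chrom for chrom, _ in train], [chrom for chrom, _ in test])
```

Splitting by whole chromosome, rather than by random windows, keeps overlapping or adjacent windows of the same region from appearing on both sides of the split.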

Training Procedure

Preprocessing

HyenaDNA used causal language modeling (CLM) as the pre-training objective: given a DNA sequence of length L, the model predicts the next nucleotide at each position, i.e., predicting token x_{t+1} given x_1, …, x_t.
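
The objective above amounts to shifting the sequence by one position and scoring each prediction with cross-entropy. The sketch below uses NumPy and a uniform stand-in for the model; the token ids (A=0, C=1, G=2, T=3) are hypothetical and do not match the actual tokenizer vocabulary.

```python
import numpy as np


def next_token_pairs(token_ids):
    """Inputs x_1..x_{L-1} and targets x_2..x_L for next-token prediction."""
    return token_ids[:-1], token_ids[1:]


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean negative log-likelihood of the targets under softmax(logits)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


# Hypothetical ids for illustration only: A=0, C=1, G=2, T=3.
tokens = np.array([0, 3, 1, 2, 0, 3])  # "ATCGAT"
inputs, targets = next_token_pairs(tokens)

# Stand-in for a model: uniform logits over the 4 nucleotides at every position.
logits = np.zeros((len(inputs), 4))
loss = cross_entropy(logits, targets)  # equals ln(4) for a uniform predictor
print(round(loss, 4))  # 1.3863
```

A trained model should drive this loss below ln(4), the value obtained by guessing uniformly among the four nucleotides.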

Single nucleotide tokenization is used with a vocabulary of 12 tokens: A, C, G, T, N, and special tokens (CLS, SEP, BOS, MASK, PAD, RESERVED, UNK).

Sequence Length Warm-up

A key training strategy in HyenaDNA is progressive sequence length warm-up. Training begins with short sequences and gradually increases the context length:

  1. Training starts with sequences of length L=64.
  2. The sequence length is doubled at each warm-up stage (64 → 128 → 256 → … → target length).
  3. This strategy enables stable training at very long context lengths that would be difficult to train from scratch.
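
The doubling schedule in the steps above can be sketched as a small helper. The function name is illustrative, and how many optimizer steps are spent at each stage is not specified here, so the sketch only produces the list of stage lengths.

```python
def seq_len_schedule(target_len: int, start_len: int = 64) -> list[int]:
    """Doubling schedule from start_len up to target_len (inclusive)."""
    if target_len < start_len:
        raise ValueError("target_len must be at least start_len")
    lengths = []
    current = start_len
    while current < target_len:
        lengths.append(current)
        current *= 2
    lengths.append(target_len)  # the final stage always uses the target length
    return lengths


print(seq_len_schedule(1024))  # [64, 128, 256, 512, 1024]
```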

Pre-training

The model was trained on up to 8 NVIDIA A100 (80GB) GPUs.

  • Batch size: 64 – 256
  • Steps: 10,000 – 20,000
  • Optimizer: AdamW
  • Learning rate: 1.5e-4 – 6e-4
  • Learning rate scheduler: Cosine
  • Weight decay: 0.1
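
As an illustration of the cosine scheduler listed above, the sketch below computes the learning rate at a few steps. The list does not mention warm-up steps or a minimum learning rate, so this version assumes no warm-up and decay to zero; the peak rate and step count are taken from the upper and lower ends of the ranges above purely as an example.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float, min_lr: float = 0.0) -> float:
    """Cosine decay from peak_lr at step 0 to min_lr at total_steps."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total_steps, peak_lr = 10_000, 6e-4  # example values, not a specific run
for step in (0, 2_500, 5_000, 7_500, 10_000):
    print(step, f"{cosine_lr(step, total_steps, peak_lr):.2e}")
```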

Citation

BibTeX
@inproceedings{nguyen2023hyenadna,
  title={Hyena{DNA}: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution},
  author={Eric Nguyen and Michael Poli and Marjan Faizi and Armin W Thomas and Michael Wornow and Callum Birch-Sykes and Stefano Massaroli and Aman Patel and Clayton M. Rabideau and Yoshua Bengio and Stefano Ermon and Christopher Re and Stephen Baccus},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=ubzNoJjOKj}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the HyenaDNA paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.hyenadna

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

  • alphabet (Alphabet | str | List[str] | None, default: None): Alphabet to use for tokenization.
    • If None, the standard DNA alphabet will be used.
    • If a string, it should correspond to the name of a predefined alphabet: standard, iupac, streamline, or nucleobase.
    • If an alphabet or a list of characters, that specific alphabet will be used.
  • nmers (int, default: 1): Size of k-mer to tokenize.
  • codon (bool, default: False): Whether to tokenize into codons.
  • replace_U_with_T (bool, default: True): Whether to replace U with T.
  • do_upper_case (bool, default: True): Whether to convert input to uppercase.

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
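
The three splitting modes handled by `_tokenize` can be tried without installing the library. The standalone function below mirrors that logic for illustration only; `split_tokens` is a hypothetical name, and it returns string tokens rather than the vocabulary ids produced by the real tokenizer.

```python
def split_tokens(text: str, nmers: int = 1, codon: bool = False) -> list[str]:
    """Split a sequence into single characters, overlapping k-mers, or codons."""
    text = text.upper().replace("U", "T")
    if codon:
        if len(text) % 3 != 0:
            raise ValueError(f"length must be a multiple of 3, but got {len(text)}")
        # Non-overlapping windows of 3 nucleotides.
        return [text[i : i + 3] for i in range(0, len(text), 3)]
    if nmers > 1:
        # Overlapping windows with stride 1, giving len(text) - nmers + 1 tokens.
        return [text[i : i + nmers] for i in range(len(text) - nmers + 1)]
    return list(text)


print(split_tokens("tataaagta"))              # 9 single-nucleotide tokens
print(split_tokens("tataaagta", nmers=3))     # 7 overlapping 3-mers
print(split_tokens("tataaagta", codon=True))  # ['TAT', 'AAA', 'GTA']
```

HyenaDNA itself uses the single-nucleotide mode (`nmers=1`, `codon=False`); the token counts for the other two modes match the doctest above, which yields 7 and 3 ids between the leading and trailing special tokens.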

HyenaDnaConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a HyenaDnaModel. It is used to instantiate a HyenaDNA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the HyenaDNA LongSafari/hyenadna-medium-160k-seqlen-hf architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default: 11): Vocabulary size of the HyenaDNA model. Defines the number of different tokens that can be represented by the input_ids passed when calling HyenaDnaModel.
  • hidden_size (int, default: 256): Dimensionality of the model layers.
  • num_hidden_layers (int, default: 8): Number of hidden layers (Hyena blocks) in the model.
  • intermediate_size (int | None, default: None): Dimensionality of the feed-forward layer. If None, defaults to 4 * hidden_size.
  • embedding_dropout (float, default: 0.1): The dropout probability for the embedding layer.
  • hidden_dropout (float, default: 0.0): The dropout probability within the Hyena operator.
  • max_position_embeddings (int, default: 160002): The maximum sequence length that this model might ever be used with.
  • initializer_range (float, default: 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • layer_norm_eps (float, default: 1e-05): The epsilon used by the layer normalization layers.
  • hyena_order (int, default: 2): Order of the Hyena recurrence. Controls the number of element-wise gating steps.
  • filter_order (int, default: 64): Width of the implicit filter MLP (number of hidden units).
  • short_filter_order (int, default: 3): Kernel size of the short depthwise convolution applied before the Hyena recurrence.
  • filter_emb_dim (int, default: 5): Dimensionality of the positional embedding fed to the implicit filter MLP. Must be odd and >= 3. Computed as (1 time) + (2 * num_frequency_bands).
  • num_inner_mlps (int, default: 2): Number of inner linear layers inside the implicit filter MLP.
  • activation_freq (int, default: 10): Frequency multiplier for the Sin activation function in the implicit filter.
  • filter_dropout (float, default: 0.0): The dropout probability for the implicit filter.
  • use_bias (bool, default: True): Whether to use bias in the implicit filter.
  • train_freq (bool, default: True): Whether the Sin activation frequencies are learnable parameters.
  • pad_vocab_size_multiple (int, default: 8): Pad the vocabulary size to be a multiple of this value (for GPU performance).
  • head (HeadConfig | None, default: None): The configuration of the head.

Examples:

Python Console Session
>>> from multimolecule import HyenaDnaConfig, HyenaDnaModel
>>> # Initializing a HyenaDNA multimolecule/hyenadna style configuration
>>> configuration = HyenaDnaConfig()
>>> # Initializing a model (with random weights) from the multimolecule/hyenadna style configuration
>>> model = HyenaDnaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/hyenadna/configuration_hyenadna.py
Python
class HyenaDnaConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`HyenaDnaModel`][multimolecule.models.HyenaDnaModel]. It is used to instantiate a HyenaDNA model according
    to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will
    yield a similar configuration to that of the HyenaDNA
    [LongSafari/hyenadna-medium-160k-seqlen-hf](https://huggingface.co/LongSafari/hyenadna-medium-160k-seqlen-hf)
    architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the HyenaDNA model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`HyenaDnaModel`].
        hidden_size:
            Dimensionality of the model layers.
        num_hidden_layers:
            Number of hidden layers (Hyena blocks) in the model.
        intermediate_size:
            Dimensionality of the feed-forward layer. If `None`, defaults to `4 * hidden_size`.
        embedding_dropout:
            The dropout probability for the embedding layer.
        hidden_dropout:
            The dropout probability within the Hyena operator.
        max_position_embeddings:
            The maximum sequence length that this model might ever be used with.
        initializer_range:
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        hyena_order:
            Order of the Hyena recurrence. Controls the number of element-wise gating steps.
        filter_order:
            Width of the implicit filter MLP (number of hidden units).
        short_filter_order:
            Kernel size of the short depthwise convolution applied before the Hyena recurrence.
        filter_emb_dim:
            Dimensionality of the positional embedding fed to the implicit filter MLP.
            Must be odd and >= 3. Computed as `(1 time) + (2 * num_frequency_bands)`.
        num_inner_mlps:
            Number of inner linear layers inside the implicit filter MLP.
        activation_freq:
            Frequency multiplier for the Sin activation function in the implicit filter.
        filter_dropout:
            The dropout probability for the implicit filter.
        use_bias:
            Whether to use bias in the implicit filter.
        train_freq:
            Whether the Sin activation frequencies are learnable parameters.
        pad_vocab_size_multiple:
            Pad the vocabulary size to be a multiple of this value (for GPU performance).
        head:
            The configuration of the head.

    Examples:
        >>> from multimolecule import HyenaDnaConfig, HyenaDnaModel
        >>> # Initializing a HyenaDNA multimolecule/hyenadna style configuration
        >>> configuration = HyenaDnaConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/hyenadna style configuration
        >>> model = HyenaDnaModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "hyenadna"

    def __init__(
        self,
        vocab_size: int = 11,
        hidden_size: int = 256,
        num_hidden_layers: int = 8,
        intermediate_size: int | None = None,
        embedding_dropout: float = 0.1,
        hidden_dropout: float = 0.0,
        max_position_embeddings: int = 160002,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-5,
        hyena_order: int = 2,
        filter_order: int = 64,
        short_filter_order: int = 3,
        filter_emb_dim: int = 5,
        num_inner_mlps: int = 2,
        activation_freq: int = 10,
        filter_dropout: float = 0.0,
        use_bias: bool = True,
        train_freq: bool = True,
        pad_vocab_size_multiple: int = 8,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if intermediate_size is None:
            intermediate_size = 4 * hidden_size
        if filter_emb_dim < 3 or filter_emb_dim % 2 == 0:
            raise ValueError(f"filter_emb_dim must be odd and at least 3, but got {filter_emb_dim}.")
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.embedding_dropout = embedding_dropout
        self.hidden_dropout = hidden_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.hyena_order = hyena_order
        self.filter_order = filter_order
        self.short_filter_order = short_filter_order
        self.filter_emb_dim = filter_emb_dim
        self.num_inner_mlps = num_inner_mlps
        self.activation_freq = activation_freq
        self.filter_dropout = filter_dropout
        self.use_bias = use_bias
        self.train_freq = train_freq
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.head = HeadConfig(**head) if head is not None else None
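
The constraint on `filter_emb_dim` (odd and at least 3) follows from its layout of one time channel plus a cosine and a sine channel per frequency band. The NumPy sketch below shows one way such an embedding could be built. It is an assumption-laden illustration, not the library's code: the function name is hypothetical, and the integer band frequencies are chosen for simplicity.

```python
import numpy as np


def filter_positional_embedding(seq_len: int, emb_dim: int = 5) -> np.ndarray:
    """Build a (seq_len, emb_dim) embedding: 1 time channel + cos/sin per band."""
    if emb_dim < 3 or emb_dim % 2 == 0:
        raise ValueError(f"emb_dim must be odd and at least 3, but got {emb_dim}.")
    bands = (emb_dim - 1) // 2
    t = np.linspace(0.0, 1.0, seq_len)[:, None]  # normalised time, shape (L, 1)
    freqs = np.arange(1, bands + 1)[None, :]     # hypothetical frequencies, shape (1, bands)
    angles = 2.0 * np.pi * freqs * t             # shape (L, bands)
    return np.concatenate([t, np.cos(angles), np.sin(angles)], axis=-1)


emb = filter_positional_embedding(seq_len=16, emb_dim=5)
print(emb.shape)  # (16, 5)
```

With the default `filter_emb_dim=5`, this layout corresponds to two frequency bands: 1 + 2 * 2 = 5 channels.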

HyenaDnaForCausalLM

Bases: HyenaDnaPreTrainedModel, GenerationMixin

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import HyenaDnaConfig, HyenaDnaForCausalLM
>>> config = HyenaDnaConfig()
>>> model = HyenaDnaForCausalLM(config)
Source code in multimolecule/models/hyenadna/modeling_hyenadna.py
Python
class HyenaDnaForCausalLM(HyenaDnaPreTrainedModel, GenerationMixin):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import HyenaDnaConfig, HyenaDnaForCausalLM
        >>> config = HyenaDnaConfig()
        >>> model = HyenaDnaForCausalLM(config)
    """

    _tied_weights_keys = {"lm_head.weight": "model.embeddings.word_embeddings.weight"}

    def __init__(self, config: HyenaDnaConfig):
        super().__init__(config)
        self.model = HyenaDnaModel(config, add_pooling_layer=False)
        vocab_size = config.vocab_size
        if vocab_size % config.pad_vocab_size_multiple != 0:
            vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
        self.lm_head = nn.Linear(config.hidden_size, vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        output_hidden_states: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutput:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
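
The vocabulary padding performed in `__init__` can be checked in isolation. The helper below restates that arithmetic as a standalone function (the name `pad_vocab_size` is illustrative): with the default `vocab_size=11` and `pad_vocab_size_multiple=8`, the language-model head ends up with 16 output rows.

```python
def pad_vocab_size(vocab_size: int, multiple: int) -> int:
    """Round vocab_size up to the nearest multiple (unchanged if already aligned)."""
    remainder = vocab_size % multiple
    if remainder != 0:
        vocab_size += multiple - remainder
    return vocab_size


print(pad_vocab_size(11, 8))  # 16
print(pad_vocab_size(16, 8))  # 16
```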

HyenaDnaForSequencePrediction

Bases: HyenaDnaPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import HyenaDnaConfig, HyenaDnaForSequencePrediction
>>> config = HyenaDnaConfig()
>>> model = HyenaDnaForSequencePrediction(config)
Source code in multimolecule/models/hyenadna/modeling_hyenadna.py
Python
class HyenaDnaForSequencePrediction(HyenaDnaPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import HyenaDnaConfig, HyenaDnaForSequencePrediction
        >>> config = HyenaDnaConfig()
        >>> model = HyenaDnaForSequencePrediction(config)
    """

    def __init__(self, config: HyenaDnaConfig):
        super().__init__(config)
        self.model = HyenaDnaModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        output_hidden_states: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

HyenaDnaForTokenPrediction

Bases: HyenaDnaPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import HyenaDnaConfig, HyenaDnaForTokenPrediction
>>> config = HyenaDnaConfig()
>>> model = HyenaDnaForTokenPrediction(config)
Source code in multimolecule/models/hyenadna/modeling_hyenadna.py
Python
class HyenaDnaForTokenPrediction(HyenaDnaPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import HyenaDnaConfig, HyenaDnaForTokenPrediction
        >>> config = HyenaDnaConfig()
        >>> model = HyenaDnaForTokenPrediction(config)
    """

    def __init__(self, config: HyenaDnaConfig):
        super().__init__(config)
        self.model = HyenaDnaModel(config, add_pooling_layer=False)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        output_hidden_states: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenPredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        output = self.token_head(outputs, attention_mask, input_ids, labels)
        logits, loss = output.logits, output.loss

        return TokenPredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

HyenaDnaModel

Bases: HyenaDnaPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import HyenaDnaConfig, HyenaDnaModel
>>> config = HyenaDnaConfig()
>>> model = HyenaDnaModel(config)
Source code in multimolecule/models/hyenadna/modeling_hyenadna.py
Python
class HyenaDnaModel(HyenaDnaPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import HyenaDnaConfig, HyenaDnaModel
        >>> config = HyenaDnaConfig()
        >>> model = HyenaDnaModel(config)
    """

    def __init__(self, config: HyenaDnaConfig, add_pooling_layer: bool = True):
        super().__init__(config)
        self.embeddings = HyenaDnaEmbeddings(config)
        self.dropout = nn.Dropout(config.embedding_dropout)
        self.layers = nn.ModuleList([HyenaDnaBlock(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = HyenaDnaPooler(config) if add_pooling_layer else None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        output_hidden_states: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | BaseModelOutputWithPoolingAndCrossAttentions:
        # Hyena's FFT-based long convolutions require a fixed sequence length; materialise
        # NestedTensor to dense + mask before entering the block stack.
        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device=inputs_embeds.device, dtype=torch.bool)
            hidden_mask = attention_mask.unsqueeze(-1)
        else:
            hidden_mask = None

        hidden_states = self.dropout(inputs_embeds)
        if hidden_mask is not None:
            hidden_states = hidden_states * hidden_mask.to(hidden_states.dtype)
        all_hidden_states: tuple[Tensor, ...] = ()

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
            if hidden_mask is not None:
                hidden_states = hidden_states * hidden_mask.to(hidden_states.dtype)

        hidden_states = self.final_layer_norm(hidden_states.to(dtype=self.final_layer_norm.weight.dtype))
        if hidden_mask is not None:
            hidden_states = hidden_states * hidden_mask.to(hidden_states.dtype)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        pooled_output = self.pooler(hidden_states) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states if output_hidden_states else None,
        )
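
The masking applied in `forward` (after the embedding dropout, after every layer, and after the final layer norm) is a plain broadcast multiply. The NumPy sketch below reproduces just that step on toy data to show the shapes involved; it stands in for the PyTorch tensors used by the model and is not part of the library.

```python
import numpy as np

hidden_states = np.ones((2, 4, 3))  # (batch, length, hidden)
attention_mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=bool)

# Add a trailing axis so the mask broadcasts over the hidden dimension.
hidden_mask = attention_mask[..., None]  # (batch, length, 1)
masked = hidden_states * hidden_mask.astype(hidden_states.dtype)

print(masked[1])  # the last two positions of the second sequence are zeroed
```

Because the long convolution mixes information along the sequence axis, values at padded positions could otherwise carry into later positions; zeroing them at each stage keeps padding from contributing to subsequent layers.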