
Basset

Deep convolutional neural network for predicting chromatin accessibility (DNase I hypersensitivity) from DNA sequence.

Disclaimer

This is an UNOFFICIAL implementation of Basset: learning the regulatory code of the accessible genome with deep convolutional neural networks by David R. Kelley, et al.

The OFFICIAL repository of Basset is at davek44/Basset.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team releasing Basset did not write this model card, so it has been written by the MultiMolecule team.

Model Details

Basset is a convolutional neural network (CNN) trained to predict the chromatin accessibility (DNase I hypersensitivity) of a DNA sequence across 164 cell types. The model consumes a fixed-length 600 bp one-hot encoded DNA sequence and applies three convolutional blocks (convolution, batch normalization, ReLU, and max pooling) followed by two fully-connected blocks before a multi-label binary classification head. Please refer to the Training Details section for more information on the training process.

Model Specification

| Num Conv Layers | Num FC Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- | --- |
| 3 | 2 | 1000 | 4.14 | 0.30 | 0.15 | 600 |
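The figures in the table can be reproduced by hand from the default layer sizes. The sketch below walks a 600 bp input through the network under assumptions of our own, not taken from the implementation: unpadded ("valid") stride-1 convolutions, non-overlapping max pooling that floors, a bias on every convolution and linear layer, and batch normalization (weight and bias) after every convolution and fully-connected layer. Under these assumptions the totals come out at 4,135,064 parameters and about 0.149 G MACs, consistent with the table when FLOPs are counted as two per MAC.

```python
# Default Basset layer sizes (from the BassetConfig defaults).
SEQUENCE_LENGTH = 600
IN_CHANNELS = 4  # one-hot A/C/G/T
CONV_CHANNELS = [300, 200, 200]
KERNEL_SIZES = [19, 11, 7]
POOL_SIZES = [3, 4, 4]
FC_SIZES = [1000, 1000]
NUM_LABELS = 164


def basset_budget() -> tuple[int, int, int]:
    """Return (positions after the conv stack, parameter count, MAC count) under the stated assumptions."""
    length, channels = SEQUENCE_LENGTH, IN_CHANNELS
    params = macs = 0
    for out_channels, kernel, pool in zip(CONV_CHANNELS, KERNEL_SIZES, POOL_SIZES):
        length = length - kernel + 1  # assumption: no padding, stride 1
        params += channels * out_channels * kernel + out_channels  # conv weight + bias
        params += 2 * out_channels  # batch-norm weight + bias
        macs += length * channels * out_channels * kernel
        length //= pool  # assumption: pool stride equals pool size, remainder dropped
        channels = out_channels
    features = channels * length  # flattened input to the first fully-connected layer
    for size in FC_SIZES:
        params += features * size + size + 2 * size  # linear weight + bias + batch norm
        macs += features * size
        features = size
    params += features * NUM_LABELS + NUM_LABELS  # output layer
    macs += features * NUM_LABELS
    return length, params, macs


length, params, macs = basset_budget()
print(f"positions after conv stack: {length}")
print(f"parameters: {params / 1e6:.2f} M")
print(f"MACs: {macs / 1e9:.2f} G, FLOPs: {2 * macs / 1e9:.2f} G")
```

Most of the budget sits in the first fully-connected layer (200 channels x 10 positions into 1000 units), which is why the parameter count is sensitive to how much the convolution and pooling stack shrinks the sequence.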

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Chromatin Accessibility Prediction

You can use this model directly to predict the DNase I hypersensitivity of a DNA sequence:

Python
>>> import torch
>>> from multimolecule import DnaTokenizer, BassetForSequencePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/basset")
>>> model = BassetForSequencePrediction.from_pretrained("multimolecule/basset")
>>> input = tokenizer("ACGT" * 150, return_tensors="pt")
>>> output = model(**input)

>>> output.logits.shape
torch.Size([1, 164])

Interface

  • Input length: fixed 600 bp DNA window
  • Output: 164 per-cell-type accessibility logits (multi-label binary)
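Because the input length is fixed, sequences of other lengths have to be brought to 600 bp before tokenization. A minimal sketch of one common approach, taking the central 600 bp of a longer region and rejecting shorter ones; `center_window` is a hypothetical helper, not part of multimolecule, and whether padding short sequences is acceptable for this checkpoint is not stated here, so the sketch refuses them.

```python
WINDOW = 600  # Basset's fixed input length in bp


def center_window(sequence: str, window: int = WINDOW) -> str:
    """Return the central `window` bases of `sequence` (hypothetical helper)."""
    sequence = sequence.upper()  # normalize soft-masked (lowercase) bases
    if len(sequence) < window:
        raise ValueError(f"sequence is {len(sequence)} bp, shorter than the {window} bp window")
    # When the surplus is odd, the extra base is trimmed from the right-hand side.
    start = (len(sequence) - window) // 2
    return sequence[start : start + window]


# A 1,000 bp region is trimmed by 200 bp on each side.
region = "A" * 200 + "ACGT" * 150 + "T" * 200
window = center_window(region)
print(len(window), window[:8])
```

The resulting string can be passed to the tokenizer exactly like `"ACGT" * 150` in the example above.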

Training Details

Basset was trained to predict the chromatin accessibility of DNA sequences across 164 cell types.

Training Data

Basset was trained on DNase I hypersensitivity peaks from ENCODE and the Roadmap Epigenomics project, covering 164 cell types. Each 600 bp genomic interval is labeled with a binary vector indicating which of the 164 cell types show an accessibility peak overlapping that interval.
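A minimal sketch of that labeling step, assuming peaks are given per cell type as half-open `(start, end)` intervals on the same chromosome as the window, and that any overlap of at least 1 bp counts as accessible. The data layout and the function names are hypothetical; the original preprocessing may apply rules that are not described here.

```python
def overlaps(a_start: int, a_end: int, b_start: int, b_end: int) -> bool:
    """True if half-open intervals [a_start, a_end) and [b_start, b_end) share at least one base."""
    return a_start < b_end and b_start < a_end


def accessibility_labels(
    window_start: int, peaks_by_cell_type: list[list[tuple[int, int]]], window: int = 600
) -> list[int]:
    """Binary vector: 1 if any peak of that cell type overlaps [window_start, window_start + window)."""
    window_end = window_start + window
    return [
        int(any(overlaps(window_start, window_end, start, end) for start, end in peaks))
        for peaks in peaks_by_cell_type
    ]


# Three cell types instead of 164, to keep the example small.
peaks = [
    [(1_000, 1_150)],  # overlaps the start of the window
    [(5_000, 5_300)],  # far outside the window
    [(300, 450), (1_590, 1_700)],  # second peak overlaps the last bases of the window
]
print(accessibility_labels(1_000, peaks))
```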

Training Procedure

Pre-training

The model was trained to minimize a multi-label binary cross-entropy loss, comparing its predicted per-cell-type accessibility probabilities against the observed DNase I hypersensitivity labels.

  • Optimizer: RMSprop
  • Loss: Multi-label binary cross-entropy
  • Regularization: Batch normalization and dropout
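For reference, the per-label loss is -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))], where z is the logit for one cell type and y in {0, 1} its label, averaged over all cell types and examples. The pure-Python sketch below uses the numerically stable rearrangement max(z, 0) - z * y + log(1 + exp(-|z|)), which is mathematically equivalent to PyTorch's `torch.nn.BCEWithLogitsLoss` with its default `reduction="mean"`. It illustrates the objective only and is not the original training code.

```python
import math


def bce_with_logits(logits: list[list[float]], labels: list[list[int]]) -> float:
    """Mean binary cross-entropy over every (example, cell type) pair."""
    total, count = 0.0, 0
    for row_logits, row_labels in zip(logits, labels):
        for z, y in zip(row_logits, row_labels):
            # Stable form of -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))].
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count


# Two sequences, three cell types (Basset itself uses 164).
logits = [[2.0, -1.0, 0.0], [-3.0, 4.0, 0.5]]
labels = [[1, 0, 1], [0, 1, 0]]
print(round(bce_with_logits(logits, labels), 4))
```

Each cell type is scored independently, so an interval can be labeled accessible in any number of cell types at once; this is why the loss is a sum of binary terms rather than a softmax cross-entropy.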

Citation

BibTeX
@article{kelley2016basset,
  author    = {Kelley, David R. and Snoek, Jasper and Rinn, John L.},
  title     = {Basset: learning the regulatory code of the accessible genome with deep convolutional neural networks},
  journal   = {Genome Research},
  volume    = 26,
  number    = 7,
  pages     = {990--999},
  year      = 2016,
  publisher = {Cold Spring Harbor Laboratory Press},
  doi       = {10.1101/gr.200535.115}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If MultiMolecule supports your research, please cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use the MultiMolecule GitHub issues for any questions or comments on the model card.

Please contact the authors of the Basset paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

API Reference

BassetConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a BassetModel. It is used to instantiate a Basset model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Basset davek44/Basset architecture.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

  • vocab_size (int, default 4): Vocabulary size of the Basset model. Basset consumes a one-hot encoding of the four DNA nucleotides, so this also defines the number of input channels of the first convolution.
  • sequence_length (int, default 600): The fixed length of the input DNA sequence in base pairs. Must be positive.
  • num_conv_layers (int, default 3): Number of convolutional layers in the encoder. conv_channels, conv_kernel_sizes and conv_pool_sizes must each have exactly this many entries.
  • conv_channels (list[int] | None, default None): Number of filters for each convolutional layer. None uses [300, 200, 200].
  • conv_kernel_sizes (list[int] | None, default None): Kernel size for each convolutional layer. None uses [19, 11, 7].
  • conv_pool_sizes (list[int] | None, default None): Max-pool size applied after each convolutional layer. None uses [3, 4, 4].
  • fc_sizes (list[int] | None, default None): Hidden dimensionality of each fully-connected layer. None uses [1000, 1000]. Must contain at least one entry; the last entry also sets hidden_size.
  • hidden_act (str, default 'relu'): The non-linear activation function in the encoder. Supported strings are "gelu", "relu", "silu" and "gelu_new".
  • hidden_dropout (float, default 0.3): The dropout probability for the fully-connected layers.
  • batch_norm_eps (float, default 1e-05): The epsilon used by the batch normalization layers.
  • batch_norm_momentum (float, default 0.1): The momentum used by the batch normalization layers.
  • num_labels (int, default 164): Number of output labels. Basset predicts DNase I hypersensitivity across 164 cell types.
  • head (HeadConfig | None, default None): The configuration of the prediction head. None yields a multi-label binary classification head (problem_type="multilabel"), matching Basset's DNase I hypersensitivity prediction task; a supplied head config without a problem_type is also set to "multilabel".
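As a sketch of how these parameters fit together, the snippet below builds a smaller two-layer variant and shows the length check on the per-layer lists. The specific sizes are arbitrary illustrations, not a recommended architecture; the behavior shown (hidden_size taken from fc_sizes, the multilabel head default, and the ValueError on mismatched lists) follows the configuration source below.

```python
from multimolecule import BassetConfig

# A hypothetical, smaller variant: two convolution blocks and one fully-connected layer.
config = BassetConfig(
    num_conv_layers=2,
    conv_channels=[128, 128],
    conv_kernel_sizes=[19, 11],
    conv_pool_sizes=[3, 4],
    fc_sizes=[512],
)
print(config.hidden_size)  # 512, taken from fc_sizes[-1]
print(config.head.problem_type)  # multilabel, set because no head was given

# Mismatched list lengths are rejected at construction time:
# conv_channels has 2 entries while the kernel and pool defaults have 3.
try:
    BassetConfig(num_conv_layers=3, conv_channels=[300, 200])
except ValueError as error:
    print(error)
```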

Examples:

Python Console Session
>>> from multimolecule import BassetConfig, BassetModel
>>> # Initializing a Basset multimolecule/basset style configuration
>>> configuration = BassetConfig()
>>> # Initializing a model (with random weights) from the multimolecule/basset style configuration
>>> model = BassetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/basset/configuration_basset.py
Python
class BassetConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BassetModel`][multimolecule.models.BassetModel]. It is used to instantiate a Basset model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the Basset [davek44/Basset](https://github.com/davek44/Basset) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the Basset model. Basset consumes a one-hot encoding of the four DNA nucleotides, so
            this also defines the number of input channels of the first convolution.
            Defaults to 4.
        sequence_length:
            The fixed length of the input DNA sequence in base pairs.
            Defaults to 600.
        num_conv_layers:
            Number of convolutional layers in the encoder.
        conv_channels:
            Number of filters for each convolutional layer.
        conv_kernel_sizes:
            Kernel size for each convolutional layer.
        conv_pool_sizes:
            Max-pool size applied after each convolutional layer.
        fc_sizes:
            Hidden dimensionality of each fully-connected layer.
        hidden_act:
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"silu"` and `"gelu_new"` are supported.
        hidden_dropout:
            The dropout probability for the fully-connected layers.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers.
        num_labels:
            Number of output labels. Basset predicts DNase I hypersensitivity across 164 cell types.
            Defaults to 164.
        head:
            The configuration of the prediction head. Defaults to a multi-label binary classification head
            (`problem_type="multilabel"`), matching Basset's DNase I hypersensitivity prediction task.

    Examples:
        >>> from multimolecule import BassetConfig, BassetModel
        >>> # Initializing a Basset multimolecule/basset style configuration
        >>> configuration = BassetConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/basset style configuration
        >>> model = BassetModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "basset"

    def __init__(
        self,
        vocab_size: int = 4,
        sequence_length: int = 600,
        num_conv_layers: int = 3,
        conv_channels: list[int] | None = None,
        conv_kernel_sizes: list[int] | None = None,
        conv_pool_sizes: list[int] | None = None,
        fc_sizes: list[int] | None = None,
        hidden_act: str = "relu",
        hidden_dropout: float = 0.3,
        batch_norm_eps: float = 1e-5,
        batch_norm_momentum: float = 0.1,
        num_labels: int = 164,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        if conv_channels is None:
            conv_channels = [300, 200, 200]
        if conv_kernel_sizes is None:
            conv_kernel_sizes = [19, 11, 7]
        if conv_pool_sizes is None:
            conv_pool_sizes = [3, 4, 4]
        if fc_sizes is None:
            fc_sizes = [1000, 1000]
        if not (len(conv_channels) == len(conv_kernel_sizes) == len(conv_pool_sizes) == num_conv_layers):
            raise ValueError(
                "conv_channels, conv_kernel_sizes and conv_pool_sizes must each have length num_conv_layers "
                f"({num_conv_layers}), but got {len(conv_channels)}, {len(conv_kernel_sizes)} and "
                f"{len(conv_pool_sizes)}."
            )
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, but got {sequence_length}.")
        if not fc_sizes:
            raise ValueError("fc_sizes must contain at least one fully-connected layer.")
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.num_conv_layers = num_conv_layers
        self.conv_channels = conv_channels
        self.conv_kernel_sizes = conv_kernel_sizes
        self.conv_pool_sizes = conv_pool_sizes
        self.fc_sizes = fc_sizes
        self.hidden_size = fc_sizes[-1]
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        # Basset performs multi-label binary classification of DNase I hypersensitivity. The MultiMolecule
        # `problem_type` convention lives on the head config, since the Transformers base config only accepts
        # the HF `problem_type` literals.
        if head is None:
            head = HeadConfig(problem_type="multilabel")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "multilabel"
        self.head = head

BassetForSequencePrediction

Bases: BassetPreTrainedModel

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import BassetConfig, BassetForSequencePrediction, DnaTokenizer
>>> config = BassetConfig()
>>> model = BassetForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/basset")
>>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (2, 164)))
>>> output["logits"].shape
torch.Size([2, 164])
>>> output["loss"]
tensor(..., grad_fn=<...>)
Source code in multimolecule/models/basset/modeling_basset.py
Python
class BassetForSequencePrediction(BassetPreTrainedModel):
    """
    Examples:
        >>> import torch
        >>> from multimolecule import BassetConfig, BassetForSequencePrediction, DnaTokenizer
        >>> config = BassetConfig()
        >>> model = BassetForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/basset")
        >>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
        >>> output = model(**input, labels=torch.randint(2, (2, 164)))
        >>> output["logits"].shape
        torch.Size([2, 164])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<...>)
    """

    def __init__(self, config: BassetConfig):
        super().__init__(config)
        self.model = BassetModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"dnase_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"dnase_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def postprocess(self, outputs: Any) -> Tensor:
        return torch.sigmoid(outputs["logits"])
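`postprocess` turns logits into independent per-cell-type probabilities with a sigmoid (not a softmax, since several cell types can be accessible at once), and `output_channels` names each column, falling back to `dnase_{index}` when the config only carries the default `LABEL_{index}` names. The pure-Python sketch below shows one way the two might be combined to list the most likely accessible cell types; `top_cell_types` is a hypothetical helper, the 0.5 threshold is an arbitrary choice, and the logits are made up.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow for large negative x)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def top_cell_types(
    logits: list[float], channels: list[str], k: int = 3, threshold: float = 0.5
) -> list[tuple[str, float]]:
    """Return up to k (channel, probability) pairs with probability >= threshold, highest first."""
    probs = [sigmoid(z) for z in logits]
    ranked = sorted(zip(channels, probs), key=lambda pair: pair[1], reverse=True)
    return [(name, p) for name, p in ranked[:k] if p >= threshold]


# Fallback channel names, mirroring output_channels when id2label is left at its defaults.
channels = [f"dnase_{index}" for index in range(5)]
logits = [2.0, -1.5, 0.3, 4.0, -0.2]
for name, p in top_cell_types(logits, channels):
    print(f"{name}: {p:.3f}")
```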

BassetModel

Bases: BassetPreTrainedModel

Examples:

Python Console Session
>>> from multimolecule import BassetConfig, BassetModel, DnaTokenizer
>>> config = BassetConfig()
>>> model = BassetModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/basset")
>>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([2, 1000])
Source code in multimolecule/models/basset/modeling_basset.py
Python
class BassetModel(BassetPreTrainedModel):
    """
    Examples:
        >>> from multimolecule import BassetConfig, BassetModel, DnaTokenizer
        >>> config = BassetConfig()
        >>> model = BassetModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/basset")
        >>> input = tokenizer(["ACGT" * 150, "TGCA" * 150], return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([2, 1000])
    """

    def __init__(self, config: BassetConfig):
        super().__init__(config)
        self.embeddings = BassetEmbedding(config)
        self.encoder = BassetEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BassetModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        # The Basset encoder collapses the sequence dimension through its fully-connected layers, so the
        # final feature vector is both the model's last hidden state and its pooled representation.
        sequence_output = self.encoder(embedding_output)

        return BassetModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=sequence_output,
        )

BassetModelOutput dataclass

Bases: ModelOutput

Base class for outputs of the Basset backbone.

Parameters:

  • last_hidden_state (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, default None): Final feature vector produced by the Basset encoder.
  • pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, default None): Same tensor as last_hidden_state; Basset collapses the sequence dimension in its encoder.
  • hidden_states (tuple[FloatTensor, ...] | None, default None): Returned when output_hidden_states=True is passed or when config.output_hidden_states=True. Tuple containing the one-hot embedding output and the final encoder feature vector.
  • attentions (tuple[FloatTensor, ...] | None, default None): Always None; Basset is a convolutional model and has no attention layers. Provided for compatibility with the Transformers output convention.
Source code in multimolecule/models/basset/modeling_basset.py
Python
@dataclass
class BassetModelOutput(ModelOutput):
    """
    Base class for outputs of the Basset backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Final feature vector produced by the Basset encoder.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Same tensor as `last_hidden_state`; Basset collapses the sequence dimension in its encoder.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple containing the one-hot embedding output and the final encoder feature vector.
        attentions:
            Always `None`; Basset is a convolutional model and has no attention layers. Provided for compatibility
            with the Transformers output convention.
    """

    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None

BassetPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weight initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/basset/modeling_basset.py
Python
class BassetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BassetConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["BassetConvLayer"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
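For `nn.Conv1d` and `nn.Linear`, Kaiming-normal initialization in `mode="fan_in"` with `nonlinearity="relu"` draws weights from a zero-mean normal with standard deviation sqrt(2) / sqrt(fan_in), where fan_in is input channels times kernel width for a convolution and input features for a linear layer. The sketch computes that scale for the default layers, assuming the `init` wrapper follows `torch.nn.init.kaiming_normal_`, and assuming (our own assumption) that unpadded convolutions and floor pooling leave 200 channels x 10 positions at the input of the first linear layer. It only matters for freshly initialized models such as `BassetModel(BassetConfig())`, since, per the comment above, weights loaded from a checkpoint are not overwritten.

```python
import math

# (layer name, fan_in) for the default configuration.
# Conv1d fan_in = in_channels * kernel_size; Linear fan_in = in_features.
# Assumption: 200 channels x 10 positions (2000 features) reach the first linear layer.
LAYERS = [
    ("conv1", 4 * 19),
    ("conv2", 300 * 11),
    ("conv3", 200 * 7),
    ("fc1", 200 * 10),
    ("fc2", 1000),
]


def kaiming_normal_std(fan_in: int, gain: float = math.sqrt(2.0)) -> float:
    """Standard deviation of Kaiming-normal init in fan_in mode (gain sqrt(2) for ReLU)."""
    return gain / math.sqrt(fan_in)


for name, fan_in in LAYERS:
    print(f"{name}: fan_in={fan_in}, std={kaiming_normal_std(fan_in):.4f}")
```

The first convolution sees only 76 inputs per output, so its weights start roughly seven times larger than those of the second convolution (fan_in 3300), keeping activation variance roughly constant across layers.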