
scBasset

Disclaimer

This is an UNOFFICIAL implementation of scBasset: sequence-based modeling of single-cell ATAC-seq using convolutional neural networks by Han Yuan et al.

The OFFICIAL repository of scBasset is at calico/scBasset.

Tip

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released scBasset did not write this model card; it was written by the MultiMolecule team.

Model Details

scBasset is a convolutional neural network (CNN) that predicts per-cell chromatin accessibility of a DNA peak sequence. The model consumes a fixed-length 1344 bp one-hot encoded DNA sequence and applies a pre-activation convolution stem, a reducing convolution tower, a pointwise convolution, and a dense bottleneck before a final cell-embedding layer that produces one accessibility logit per single cell.
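The length and width bookkeeping in this stack can be traced with a few lines of arithmetic. The sketch below assumes that every convolution uses length-preserving ("same") padding, as in the upstream Keras code, so only the max-pooling layers shrink the sequence. `tower_widths` and `trace_shapes` are hypothetical helpers, and the tower widths follow the upstream `filters_init=288`, `filters_mult=1.122`, `repeat=6` recipe with integer rounding.

```python
def tower_widths(filters_init: float = 288, filters_mult: float = 1.122, repeat: int = 6) -> list[int]:
    """Reproduce the upstream conv_tower widths: round the current width, then multiply."""
    widths, current = [], float(filters_init)
    for _ in range(repeat):
        widths.append(int(round(current)))
        current *= filters_mult
    return widths


def trace_shapes(sequence_length: int = 1344, stem_channels: int = 288, stem_pool: int = 3,
                 tower_pool: int = 2, pointwise_channels: int = 256, bottleneck_size: int = 32):
    """Return (stage, width, length) tuples, assuming length-preserving ('same') convolutions."""
    stages = []
    length = sequence_length // stem_pool  # stem conv, then max-pool of size 3
    stages.append(("stem", stem_channels, length))
    for index, width in enumerate(tower_widths(), start=1):
        length //= tower_pool  # each tower conv is followed by a max-pool of size 2
        stages.append((f"tower_{index}", width, length))
    stages.append(("pointwise", pointwise_channels, length))  # kernel size 1, no pooling
    stages.append(("flatten", pointwise_channels * length, 1))  # input size of the dense bottleneck
    stages.append(("bottleneck", bottleneck_size, 1))
    return stages


for name, width, length in trace_shapes():
    print(f"{name:>10}: width={width:5d} length={length}")
```

Under these assumptions the 1344 bp input shrinks to 448 positions after the stem pool and to 7 after the sixth tower block, so the dense bottleneck receives 256 × 7 = 1792 features. The trace also accounts for the eight convolution layers in the model specification: one stem, six tower and one pointwise convolution.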

scBasset uses a pre-activation block layout: each convolution block applies the activation (the sigmoid approximation of GELU, sigmoid(1.702 * x) * x) before the convolution, followed by batch normalization and max pooling. The dense bottleneck flattens the convolution output in Keras channels-last (length-major) order. This ordering is load-bearing because the pretrained bottleneck weights expect it, so the MultiMolecule implementation reconciles it with PyTorch's channels-first convolution layout.
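The activation and the flatten order described above can be checked in plain Python. This sketch assumes a feature map stored channels-first as `[channels][length]` (the layout PyTorch's `Conv1d` produces) and shows that reproducing the Keras length-major flatten means reading position by position rather than channel by channel. For a batched PyTorch tensor shaped `(N, C, L)`, the same reordering corresponds to `x.transpose(1, 2).flatten(1)`. The helper names are illustrative and not part of the MultiMolecule API.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def quick_gelu(x: float) -> float:
    """Sigmoid approximation of GELU used by scBasset: sigmoid(1.702 * x) * x."""
    return x * sigmoid(1.702 * x)


def flatten_channels_last(x: list[list[float]]) -> list[float]:
    """Flatten a [channels][length] map in Keras (length-major) order: index l * C + c holds x[c][l]."""
    channels, length = len(x), len(x[0])
    return [x[c][l] for l in range(length) for c in range(channels)]


def flatten_channels_first(x: list[list[float]]) -> list[float]:
    """Naive channel-major flatten, i.e. what flattening a (C, L) tensor directly would give."""
    return [value for row in x for value in row]


feature_map = [[0, 1, 2], [10, 11, 12]]  # 2 channels, length 3
print(flatten_channels_last(feature_map))   # [0, 10, 1, 11, 2, 12]
print(flatten_channels_first(feature_map))  # [0, 1, 2, 10, 11, 12]
print(quick_gelu(1.0))
```

Feeding the channel-major vector to weights trained on the length-major one would silently scramble the bottleneck inputs, which is why the ordering matters when porting the checkpoint.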

Important

The final cell-embedding (dense) layer of scBasset is dataset-specific: it has one row per single cell in the training atlas, so there is no dataset-independent foundation checkpoint. The default scBasset variant uses the Buenrostro2018 hematopoiesis tutorial dataset distributed by the scBasset authors, which has 2034 single cells (so num_labels = 2034). A different scBasset dataset would have a different number of cells and a differently sized cell-embedding layer.
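Because only this final layer depends on the atlas, its size is simple arithmetic: one weight per (bottleneck unit, cell) pair, plus one bias per cell if the layer has a bias term (an assumption in this sketch). With the default 32-dimensional bottleneck, the 2034-cell checkpoint would therefore carry 32 × 2034 + 2034 = 67,122 cell-embedding parameters. The 10,000-cell atlas in the loop is hypothetical.

```python
def cell_embedding_parameters(bottleneck_size: int, num_cells: int, bias: bool = True) -> int:
    """Parameter count of a dense layer mapping the bottleneck to one logit per cell."""
    weights = bottleneck_size * num_cells
    return weights + (num_cells if bias else 0)


# Default Buenrostro2018 checkpoint versus a hypothetical larger atlas.
for num_cells in (2034, 10_000):
    print(num_cells, cell_embedding_parameters(bottleneck_size=32, num_cells=num_cells))
```

Everything before this layer keeps the same shape regardless of the dataset, so the layer grows linearly with the number of cells.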

The cell-embedding layer is exposed through the shared SequencePredictionHead; the per-cell accessibility task is modeled as a binary problem (problem_type="binary").

Model Specification

| Num Conv Layers | Hidden Size | Num Cells | Num Parameters (M) | FLOPs (G) | MACs (G) | Max Num Tokens |
| --- | --- | --- | --- | --- | --- | --- |
| 8 | 32 | 2034 | 4.59 | 0.95 | 0.47 | 1344 |

Usage

The model file depends on the multimolecule library. You can install it using pip:

Bash
pip install multimolecule

Direct Use

Single-Cell Chromatin Accessibility Prediction

You can use this model directly to predict per-cell chromatin accessibility of a DNA peak sequence:

Python
>>> from multimolecule import DnaTokenizer, ScBassetForSequencePrediction

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/scbasset")
>>> model = ScBassetForSequencePrediction.from_pretrained("multimolecule/scbasset")
>>> input = tokenizer("ACGT" * 336, return_tensors="pt")
>>> output = model(**input)

>>> output.logits.shape
torch.Size([1, 2034])

Each of the 2034 logits is a per-cell accessibility score for the Buenrostro2018 hematopoiesis atlas.
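The logits are independent per cell, so each one becomes an accessibility probability through its own sigmoid rather than a softmax across cells. In PyTorch this is `torch.sigmoid(output.logits)`, which is what `ScBassetForSequencePrediction.postprocess` applies. The plain-Python sketch below does the same for a list of logits and ranks cells by predicted accessibility; the logit values are made up, and with a real model you would pass something like `output.logits[0].tolist()`.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def top_cells(logits: list[float], k: int = 3) -> list[tuple[int, float]]:
    """Return the k (cell index, probability) pairs with the highest predicted accessibility."""
    probabilities = [sigmoid(z) for z in logits]
    order = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(i, probabilities[i]) for i in order[:k]]


example_logits = [-2.0, 0.5, 3.0, -0.1]  # made-up values for four cells
print(top_cells(example_logits, k=2))
```

Since every cell has its own sigmoid, the probabilities do not sum to one across cells; each is the chance that the peak is open in that particular cell.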

Interface

  • Input length: fixed 1344 bp DNA peak window
  • Output: per-cell accessibility logits (2034 cells in the default Buenrostro2018 hematopoiesis atlas; cell count is dataset-specific)
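Peaks called from single-cell ATAC-seq data vary in width, so they must be brought to the fixed 1344 bp window before tokenization. A common approach, and the assumption in this sketch, is to center a 1344 bp window on the peak midpoint and pad with `N` where the window runs past a chromosome end (the converter leaves the `N` input channel at zero). `center_window` and `fetch_window` are hypothetical helpers using half-open `[start, end)` coordinates; the scBasset authors' own preprocessing may differ in detail.

```python
def center_window(start: int, end: int, length: int = 1344) -> tuple[int, int]:
    """Half-open [start, end) coordinates of a `length`-bp window centred on the peak midpoint."""
    midpoint = (start + end) // 2
    window_start = midpoint - length // 2
    return window_start, window_start + length


def fetch_window(chrom_seq: str, start: int, end: int, length: int = 1344) -> str:
    """Return the centred window from a chromosome string, padding with 'N' beyond either end."""
    window_start, window_end = center_window(start, end, length)
    return "".join(
        chrom_seq[i] if 0 <= i < len(chrom_seq) else "N"
        for i in range(window_start, window_end)
    )


toy_chromosome = "ACGT" * 500  # 2000 bp stand-in for a real chromosome
window = fetch_window(toy_chromosome, start=100, end=300)
print(len(window), window[:5], window[472:480])
```

The resulting 1344-character string can then be passed to the tokenizer exactly like `"ACGT" * 336` in the usage example.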

Training Details

scBasset was trained to predict the per-cell chromatin accessibility of DNA peak sequences across a single-cell ATAC-seq atlas.

Training Data

The default scBasset checkpoint is the Buenrostro2018 hematopoiesis tutorial model, trained on the Buenrostro et al. 2018 single-cell ATAC-seq hematopoiesis dataset (2034 single cells). Each 1344 bp peak is paired with a per-cell binary accessibility vector.

Training Procedure

Pre-training

The model was trained to minimize a per-cell binary cross-entropy loss, comparing its predicted per-cell accessibility probabilities (sigmoid of the cell-embedding logits) against the observed single-cell ATAC-seq accessibility labels.

  • Optimizer: Adam
  • Loss: Per-cell binary cross-entropy
  • Regularization: Batch normalization and dropout
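The per-cell loss can be computed directly on logits. The sketch below uses the numerically stable form max(x, 0) - x * y + log(1 + exp(-|x|)), which is the formula PyTorch's `BCEWithLogitsLoss` evaluates. Averaging over the cells of a single peak is an assumption here; the reduction used in the original training code may differ.

```python
import math


def bce_with_logits(logit: float, label: float) -> float:
    """Stable binary cross-entropy on one raw logit: max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def per_cell_bce(logits: list[float], labels: list[int]) -> float:
    """Mean BCE over the cells of one peak; every cell is an independent binary target."""
    if len(logits) != len(labels):
        raise ValueError("expected one accessibility label per cell")
    return sum(bce_with_logits(x, y) for x, y in zip(logits, labels)) / len(logits)


# Made-up logits for four cells and their observed accessibility (1 = accessible).
print(per_cell_bce([2.0, -1.5, 0.3, -3.0], [1, 0, 0, 0]))
```

Working on logits avoids computing sigmoid(x) and then log(sigmoid(x)) separately, which loses precision for large |x|.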

Citation

BibTeX
@article{yuan2022scbasset,
  author    = {Yuan, Han and Kelley, David R.},
  title     = {scBasset: sequence-based modeling of single-cell ATAC-seq using convolutional neural networks},
  journal   = {Nature Methods},
  volume    = 19,
  number    = 9,
  pages     = {1088--1096},
  year      = 2022,
  publisher = {Nature Publishing Group},
  doi       = {10.1038/s41592-022-01562-8}
}

Note

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

BibTeX
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use MultiMolecule's GitHub issues for any questions or comments on the model card.

Please contact the authors of the scBasset paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

Text Only
SPDX-License-Identifier: AGPL-3.0-or-later

multimolecule.models.scbasset

DnaTokenizer

Bases: Tokenizer

Tokenizer for DNA sequences.

Parameters:

Name Type Description Default

alphabet

Alphabet | str | List[str] | None

alphabet to use for tokenization.

  • If None, the standard DNA alphabet will be used.
  • If a string, it should correspond to the name of a predefined alphabet. The options include
    • standard
    • iupac
    • streamline
    • nucleobase
  • If an alphabet or a list of characters, that specific alphabet will be used.
None

nmers

int

Size of kmer to tokenize.

1

codon

bool

Whether to tokenize into codons.

False

replace_U_with_T

bool

Whether to replace U with T.

True

do_upper_case

bool

Whether to convert input to uppercase.

True

Examples:

Python Console Session
>>> from multimolecule import DnaTokenizer
>>> tokenizer = DnaTokenizer()
>>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
>>> tokenizer('acgt')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 9, 2]
>>> tokenizer = DnaTokenizer(replace_U_with_T=False)
>>> tokenizer('acgu')["input_ids"]
[1, 6, 7, 8, 3, 2]
>>> tokenizer = DnaTokenizer(nmers=3)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 21, 81, 6, 8, 19, 71, 2]
>>> tokenizer = DnaTokenizer(codon=True)
>>> tokenizer('tataaagta')["input_ids"]
[1, 84, 6, 71, 2]
>>> tokenizer('tataaagtaa')["input_ids"]
Traceback (most recent call last):
ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
Source code in multimolecule/tokenisers/dna/tokenization_dna.py
Python
class DnaTokenizer(Tokenizer):
    """
    Tokenizer for DNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If `None`, the standard DNA alphabet will be used.
            - If a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_U_with_T: Whether to replace U with T.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import DnaTokenizer
        >>> tokenizer = DnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGTNRYSWKMBDHVX|.*-?')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = DnaTokenizer(replace_U_with_T=False)
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = DnaTokenizer(nmers=3)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 21, 81, 6, 8, 19, 71, 2]
        >>> tokenizer = DnaTokenizer(codon=True)
        >>> tokenizer('tataaagta')["input_ids"]
        [1, 84, 6, 71, 2]
        >>> tokenizer('tataaagtaa')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        codon: bool = False,
        replace_U_with_T: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_U_with_T=replace_U_with_T,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_U_with_T = replace_U_with_T
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        if self.do_upper_case:
            text = text.upper()
        if self.replace_U_with_T:
            text = text.replace("U", "T")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)
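The `nmers` and `codon` branches of `_tokenize` differ only in their stride, which is why the `nmers=3` example above yields seven k-mer ids between the `<cls>` (1) and `<eos>` (2) ids while the codon example yields three. The standalone sketch below reproduces just the splitting step, without any vocabulary lookup; the shared triplets `TAT`, `AAA` and `GTA` map to the same ids (84, 6, 71) in both examples. Note that the scBasset configuration expects single-nucleotide input, since its `vocab_size` of 5 corresponds to the `streamline` alphabet (A, C, G, T, N).

```python
def kmer_tokens(text: str, k: int) -> list[str]:
    """Overlapping k-mers with stride 1, as in the nmers branch of DnaTokenizer._tokenize."""
    return [text[i : i + k] for i in range(len(text) - k + 1)]


def codon_tokens(text: str) -> list[str]:
    """Non-overlapping triplets; the length must be a multiple of 3, as in the codon branch."""
    if len(text) % 3 != 0:
        raise ValueError(f"length must be a multiple of 3 for codon tokenization, got {len(text)}")
    return [text[i : i + 3] for i in range(0, len(text), 3)]


sequence = "TATAAAGTA"
print(kmer_tokens(sequence, 3))  # ['TAT', 'ATA', 'TAA', 'AAA', 'AAG', 'AGT', 'GTA']
print(codon_tokens(sequence))    # ['TAT', 'AAA', 'GTA']
```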

ScBassetConfig

Bases: PreTrainedConfig

This is the configuration class to store the configuration of a ScBassetModel. It is used to instantiate a scBasset model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the scBasset calico/scBasset architecture as distributed for the Buenrostro2018 hematopoiesis tutorial checkpoint.

Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.

Parameters:

Name Type Description Default

vocab_size

int

Vocabulary size of the scBasset model. scBasset consumes a one-hot encoding of DNA nucleotides, so this also defines the number of input channels of the first convolution. The default of 5 matches the MultiMolecule streamline DNA alphabet (A, C, G, T, N); the upstream four-channel kernel is reordered into this slot layout in the converter, leaving the N channel zero. Defaults to 5.

5

sequence_length

int

The fixed length of the input DNA peak sequence in base pairs. Defaults to 1344.

1344

stem_channels

int

Number of filters in the stem (first) convolution. Defaults to 288.

288

stem_kernel_size

int

Kernel size of the stem convolution. Defaults to 17.

17

stem_pool_size

int

Max-pool size applied after the stem convolution. Defaults to 3.

3

tower_channels

list[int] | None

Number of filters for each convolution in the reducing tower. The upstream architecture derives these from filters_init=288, filters_mult=1.122, repeat=6 with integer rounding. Defaults to [288, 323, 363, 407, 456, 512] when None.

None

tower_kernel_size

int

Kernel size for each tower convolution. Defaults to 5.

5

tower_pool_size

int

Max-pool size applied after each tower convolution. Defaults to 2.

2

pointwise_channels

int

Number of filters in the final pointwise (kernel size 1) convolution. Defaults to 256.

256

bottleneck_size

int

Dimensionality of the dense bottleneck embedding. This is the model’s hidden size. Defaults to 32.

32

hidden_act

str

The non-linear activation function (function or string) in the encoder. scBasset uses the sigmoid approximation of GELU (sigmoid(1.702 * x) * x), exposed by Transformers as "quick_gelu".

'quick_gelu'

hidden_dropout

float

The dropout probability for the bottleneck.

0.2

batch_norm_eps

float

The epsilon used by the batch normalization layers.

0.001

batch_norm_momentum

float

The momentum used by the batch normalization layers.

0.1

num_labels

int

Number of output labels. scBasset predicts per-cell chromatin accessibility, so this equals the number of single cells in the training atlas and is dataset-specific. The shipped Buenrostro2018 hematopoiesis checkpoint has 2034 cells. Defaults to 2034.

2034

head

HeadConfig | None

The configuration of the prediction head. Defaults to a per-cell binary accessibility head (problem_type="binary"), matching scBasset’s per-cell accessibility task.

None
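Two of these defaults reflect the Keras origin of the checkpoint. PyTorch's `BatchNorm1d` defaults to `eps=1e-5`, whereas Keras `BatchNormalization` defaults to `epsilon=1e-3`, which is the value used here. The momentum conventions are also mirrored: Keras weights the old running statistic by `momentum`, while PyTorch weights the new batch statistic by it, so a Keras momentum of 0.9 (assumed here to be the upstream setting) corresponds to the PyTorch value 0.1. The sketch below checks that equivalence on scalar running means.

```python
KERAS_DEFAULT_EPSILON = 1e-3  # tf.keras.layers.BatchNormalization default
TORCH_DEFAULT_EPS = 1e-5      # torch.nn.BatchNorm1d default


def keras_running_update(running: float, batch_stat: float, momentum: float) -> float:
    """Keras convention: `momentum` weights the previous running statistic."""
    return momentum * running + (1.0 - momentum) * batch_stat


def torch_running_update(running: float, batch_stat: float, momentum: float) -> float:
    """PyTorch convention: `momentum` weights the current batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


keras_momentum = 0.9                   # assumed upstream Keras setting
torch_momentum = 1.0 - keras_momentum  # approximately 0.1, the ScBassetConfig default
for running, batch_stat in [(0.0, 1.0), (2.5, -0.5)]:
    print(keras_running_update(running, batch_stat, keras_momentum),
          torch_running_update(running, batch_stat, torch_momentum))
```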

Examples:

Python Console Session
>>> from multimolecule import ScBassetConfig, ScBassetModel
>>> # Initializing a scBasset multimolecule/scbasset style configuration
>>> configuration = ScBassetConfig()
>>> # Initializing a model (with random weights) from the multimolecule/scbasset style configuration
>>> model = ScBassetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
Source code in multimolecule/models/scbasset/configuration_scbasset.py
Python
class ScBassetConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`ScBassetModel`][multimolecule.models.ScBassetModel]. It is used to instantiate a scBasset model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a
    similar configuration to that of the scBasset [calico/scBasset](https://github.com/calico/scBasset) architecture as
    distributed for the Buenrostro2018 hematopoiesis tutorial checkpoint.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    Args:
        vocab_size:
            Vocabulary size of the scBasset model. scBasset consumes a one-hot encoding of DNA nucleotides, so this
            also defines the number of input channels of the first convolution. The default of 5 matches the
            MultiMolecule `streamline` DNA alphabet (`A`, `C`, `G`, `T`, `N`); the upstream four-channel kernel is
            reordered into this slot layout in the converter, leaving the `N` channel zero.
            Defaults to 5.
        sequence_length:
            The fixed length of the input DNA peak sequence in base pairs.
            Defaults to 1344.
        stem_channels:
            Number of filters in the stem (first) convolution.
            Defaults to 288.
        stem_kernel_size:
            Kernel size of the stem convolution.
            Defaults to 17.
        stem_pool_size:
            Max-pool size applied after the stem convolution.
            Defaults to 3.
        tower_channels:
            Number of filters for each convolution in the reducing tower. The upstream architecture derives these
            from `filters_init=288`, `filters_mult=1.122`, `repeat=6` with integer rounding.
            Defaults to `[288, 323, 363, 407, 456, 512]` when `None`.
        tower_kernel_size:
            Kernel size for each tower convolution.
            Defaults to 5.
        tower_pool_size:
            Max-pool size applied after each tower convolution.
            Defaults to 2.
        pointwise_channels:
            Number of filters in the final pointwise (kernel size 1) convolution.
            Defaults to 256.
        bottleneck_size:
            Dimensionality of the dense bottleneck embedding. This is the model's hidden size.
            Defaults to 32.
        hidden_act:
            The non-linear activation function (function or string) in the encoder. scBasset uses the sigmoid
            approximation of GELU (`sigmoid(1.702 * x) * x`), exposed by Transformers as `"quick_gelu"`.
        hidden_dropout:
            The dropout probability for the bottleneck.
        batch_norm_eps:
            The epsilon used by the batch normalization layers.
        batch_norm_momentum:
            The momentum used by the batch normalization layers.
        num_labels:
            Number of output labels. scBasset predicts per-cell chromatin accessibility, so this equals the number
            of single cells in the training atlas and is **dataset-specific**. The shipped Buenrostro2018
            hematopoiesis checkpoint has 2034 cells.
            Defaults to 2034.
        head:
            The configuration of the prediction head. Defaults to a per-cell binary accessibility head
            (`problem_type="binary"`), matching scBasset's per-cell accessibility task.

    Examples:
        >>> from multimolecule import ScBassetConfig, ScBassetModel
        >>> # Initializing a scBasset multimolecule/scbasset style configuration
        >>> configuration = ScBassetConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/scbasset style configuration
        >>> model = ScBassetModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "scbasset"

    def __init__(
        self,
        vocab_size: int = 5,
        sequence_length: int = 1344,
        stem_channels: int = 288,
        stem_kernel_size: int = 17,
        stem_pool_size: int = 3,
        tower_channels: list[int] | None = None,
        tower_kernel_size: int = 5,
        tower_pool_size: int = 2,
        pointwise_channels: int = 256,
        bottleneck_size: int = 32,
        hidden_act: str = "quick_gelu",
        hidden_dropout: float = 0.2,
        batch_norm_eps: float = 1e-3,
        batch_norm_momentum: float = 0.1,
        num_labels: int = 2034,
        head: HeadConfig | None = None,
        **kwargs,
    ):
        super().__init__(num_labels=num_labels, **kwargs)
        if tower_channels is None:
            # Upstream conv_tower(filters_init=288, filters_mult=1.122, repeat=6) with int(round(...)) rounding.
            tower_channels = [288, 323, 363, 407, 456, 512]
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, but got {sequence_length}.")
        if bottleneck_size <= 0:
            raise ValueError(f"bottleneck_size must be positive, but got {bottleneck_size}.")
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.stem_channels = stem_channels
        self.stem_kernel_size = stem_kernel_size
        self.stem_pool_size = stem_pool_size
        self.tower_channels = tower_channels
        self.tower_kernel_size = tower_kernel_size
        self.tower_pool_size = tower_pool_size
        self.pointwise_channels = pointwise_channels
        self.bottleneck_size = bottleneck_size
        self.hidden_size = bottleneck_size
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        # scBasset performs per-cell binary accessibility prediction. The MultiMolecule `problem_type` convention
        # lives on the head config, since the Transformers base config only accepts the HF `problem_type` literals.
        if head is None:
            head = HeadConfig(problem_type="binary")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "binary"
        self.head = head

ScBassetForSequencePrediction

Bases: ScBassetPreTrainedModel

The cell-embedding (final dense) layer of scBasset is dataset-specific: it has one row per single cell in the training atlas. num_labels therefore equals the number of cells in the chosen dataset (2034 for the shipped Buenrostro2018 hematopoiesis checkpoint) and is exposed through the shared SequencePredictionHead decoder.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule import ScBassetConfig, ScBassetForSequencePrediction, DnaTokenizer
>>> config = ScBassetConfig()
>>> model = ScBassetForSequencePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/scbasset")
>>> input = tokenizer(["ACGT" * 336, "TGCA" * 336], return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (2, 2034)))
>>> output["logits"].shape
torch.Size([2, 2034])
>>> output["loss"]
tensor(..., grad_fn=<...>)
Source code in multimolecule/models/scbasset/modeling_scbasset.py
Python
class ScBassetForSequencePrediction(ScBassetPreTrainedModel):
    """
    The cell-embedding (final dense) layer of scBasset is **dataset-specific**: it has one row per single cell in
    the training atlas. `num_labels` therefore equals the number of cells in the chosen dataset (2034 for the
    shipped Buenrostro2018 hematopoiesis checkpoint) and is exposed through the shared
    [`SequencePredictionHead`][multimolecule.SequencePredictionHead] decoder.

    Examples:
        >>> import torch
        >>> from multimolecule import ScBassetConfig, ScBassetForSequencePrediction, DnaTokenizer
        >>> config = ScBassetConfig()
        >>> model = ScBassetForSequencePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/scbasset")
        >>> input = tokenizer(["ACGT" * 336, "TGCA" * 336], return_tensors="pt")
        >>> output = model(**input, labels=torch.randint(2, (2, 2034)))
        >>> output["logits"].shape
        torch.Size([2, 2034])
        >>> output["loss"]  # doctest:+ELLIPSIS
        tensor(..., grad_fn=<...>)
    """

    def __init__(self, config: ScBassetConfig):
        super().__init__(config)
        self.model = ScBassetModel(config)
        self.sequence_head = SequencePredictionHead(config)
        self.head_config = self.sequence_head.config

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None:
            labels = [str(id2label.get(index, f"cell_{index}")) for index in range(self.config.num_labels)]
            if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
                return labels
        return [f"cell_{index}" for index in range(self.config.num_labels)]

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Tuple[Tensor, ...] | SequencePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        output = self.sequence_head(outputs, labels)
        logits, loss = output.logits, output.loss

        return SequencePredictorOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def postprocess(self, outputs: Any) -> Tensor:
        return torch.sigmoid(outputs["logits"])
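The `output_channels` property names each output after a cell. It keeps the `id2label` names only if at least one of them differs from the `LABEL_i` placeholders that Transformers assigns by default, and otherwise falls back to `cell_i`. The plain-Python mirror below is a hypothetical standalone function, not a library call, that makes the rule easy to test; the cell-type names in the usage are invented.

```python
from typing import Dict, List, Optional


def output_channels(num_labels: int, id2label: Optional[Dict[int, str]]) -> List[str]:
    """Use id2label names unless every one is a default LABEL_i placeholder; else use cell_i."""
    if id2label is not None:
        labels = [str(id2label.get(index, f"cell_{index}")) for index in range(num_labels)]
        if any(label != f"LABEL_{index}" for index, label in enumerate(labels)):
            return labels
    return [f"cell_{index}" for index in range(num_labels)]


print(output_channels(3, None))                          # ['cell_0', 'cell_1', 'cell_2']
print(output_channels(2, {0: "LABEL_0", 1: "LABEL_1"}))  # ['cell_0', 'cell_1']
print(output_channels(2, {0: "HSC", 1: "CMP"}))          # ['HSC', 'CMP']
```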

ScBassetModel

Bases: ScBassetPreTrainedModel

Examples:

Python Console Session
>>> from multimolecule import ScBassetConfig, ScBassetModel, DnaTokenizer
>>> config = ScBassetConfig()
>>> model = ScBassetModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/scbasset")
>>> input = tokenizer(["ACGT" * 336, "TGCA" * 336], return_tensors="pt")
>>> output = model(**input)
>>> output["pooler_output"].shape
torch.Size([2, 32])
Source code in multimolecule/models/scbasset/modeling_scbasset.py
Python
class ScBassetModel(ScBassetPreTrainedModel):
    """
    Examples:
        >>> from multimolecule import ScBassetConfig, ScBassetModel, DnaTokenizer
        >>> config = ScBassetConfig()
        >>> model = ScBassetModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/scbasset")
        >>> input = tokenizer(["ACGT" * 336, "TGCA" * 336], return_tensors="pt")
        >>> output = model(**input)
        >>> output["pooler_output"].shape
        torch.Size([2, 32])
    """

    def __init__(self, config: ScBassetConfig):
        super().__init__(config)
        self.embeddings = ScBassetEmbedding(config)
        self.encoder = ScBassetEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPoolingAndCrossAttentions:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        # The scBasset encoder collapses the sequence dimension through its dense bottleneck, so the final
        # bottleneck embedding is both the model's last hidden state and its pooled representation.
        sequence_output = self.encoder(embedding_output)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=sequence_output,
        )

ScBassetPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in multimolecule/models/scbasset/modeling_scbasset.py
Python
class ScBassetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ScBassetConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["ScBassetConvLayer"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm, nn.GroupNorm)):
            init.ones_(module.weight)
            init.zeros_(module.bias)
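For a randomly initialised model (weights loaded from a checkpoint are left untouched, per the comment above), `kaiming_normal_` with `mode="fan_in"` and `nonlinearity="relu"` draws weights with standard deviation sqrt(2 / fan_in), where fan_in is `in_channels * kernel_size` for `nn.Conv1d` and `in_features` for `nn.Linear`. The sketch below evaluates this for a few layers of the default configuration; the 1792-feature bottleneck input assumes length-preserving convolutions (256 channels × 7 positions).

```python
import math


def kaiming_normal_std(fan_in: int) -> float:
    """Std of kaiming_normal_(mode='fan_in', nonlinearity='relu'): gain sqrt(2) over sqrt(fan_in)."""
    return math.sqrt(2.0) / math.sqrt(fan_in)


def conv1d_fan_in(in_channels: int, kernel_size: int) -> int:
    """fan_in of an nn.Conv1d weight shaped (out_channels, in_channels, kernel_size)."""
    return in_channels * kernel_size


layers = {
    "stem (5 -> 288, k=17)": conv1d_fan_in(5, 17),
    "tower_1 (288 -> 288, k=5)": conv1d_fan_in(288, 5),
    "pointwise (512 -> 256, k=1)": conv1d_fan_in(512, 1),
    "bottleneck (1792 -> 32)": 1792,  # nn.Linear: fan_in is in_features
}
for name, fan_in in layers.items():
    print(f"{name}: fan_in={fan_in}, std={kaiming_normal_std(fan_in):.4f}")
```

Wider inputs get proportionally smaller initial weights, which keeps activation variance roughly constant through the tower.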