
heads

heads provides a collection of model prediction heads for handling different tasks.

heads accepts a ModelOutput, a dict, or a tuple as input. It automatically locates the model output required for prediction and processes it accordingly.

Some prediction heads may require additional information, such as attention_mask or input_ids; ContactPredictionHead is one example. These additional arguments can be passed in as positional or keyword arguments.

Note that heads follows the same ModelOutput conventions as 🤗 Transformers. If the model output is a tuple, we treat the first element as the last_hidden_state, the second element as the pooler_output, and the last element as the attentions. It is the user's responsibility to ensure that the model output is in the correct format.

If the model output is a ModelOutput or a dict, heads will look up HeadConfig.output_name in the model output. You can specify output_name in the HeadConfig to make sure heads can correctly locate the required tensor.
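
As a minimal sketch of this lookup (the shapes and configuration values below are illustrative assumptions, and the import path follows the ContactPredictionHead examples later on this page):

Python
import torch
from multimolecule.models import PreTrainedConfig
from multimolecule.modules.heads import SequencePredictionHead

# Illustrative configuration; hidden_size and num_labels are arbitrary.
config = PreTrainedConfig(hidden_size=8, num_labels=2)
head = SequencePredictionHead(config)

# With a dict input, the head looks up its default output_name ("pooler_output").
pooler_output = torch.randn(1, config.hidden_size)
output = head({"pooler_output": pooler_output})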

multimolecule.modules.heads.config

HeadConfig

Bases: BaseHeadConfig

Configuration class for a prediction head.

Parameters:

num_labels (int | None, default: None)
    Number of labels to use in the last layer added to the model, typically for a classification task.
    Head should look for Config.num_labels if it is None.

problem_type (str | None, default: None)
    Problem type for XxxForYyyPrediction models. Can be one of "binary", "regression", "multiclass" or "multilabel".
    Head should look for Config.problem_type if it is None.

hidden_size (int | None, default: None)
    Dimensionality of the encoder layers and the pooler layer.
    Head should look for Config.hidden_size if it is None.

dropout (float, default: 0.0)
    The dropout ratio for the hidden states.

transform (str | None, default: None)
    The transform operation applied to hidden states.

transform_act (str | None, default: "gelu")
    The activation function of transform applied to hidden states.

bias (bool, default: True)
    Whether to apply bias to the final prediction layer.

act (str | None, default: None)
    The activation function of the final prediction output.

layer_norm_eps (float, default: 1e-12)
    The epsilon used by the layer normalization layers.

output_name (str | None, default: None)
    The name of the tensor required in model outputs.
    If it is None, will use the default output name of the corresponding head.

type (str | None, default: None)
    The type of the head in the model.
    This is used by MultiMoleculeModel to construct heads.
Source code in multimolecule/modules/heads/config.py
Python
class HeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a prediction head.

    Args:
        num_labels:
            Number of labels to use in the last layer added to the model, typically for a classification task.

            Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`.
        problem_type:
            Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`,
            `"multiclass"` or `"multilabel"`.

            Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`.
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            Head should look for [`Config.hidden_size`][multimolecule.PreTrainedConfig] if is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name:
            The name of the tensor required in model outputs.

            If is `None`, will use the default output name of the corresponding head.
        type:
            The type of the head in the model.

            This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads.
    """

    num_labels: Optional[int] = None
    problem_type: Optional[str] = None
    hidden_size: Optional[int] = None
    dropout: float = 0.0
    transform: Optional[str] = None
    transform_act: Optional[str] = "gelu"
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
    output_name: Optional[str] = None
    type: Optional[str] = None
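
A quick illustration of constructing a HeadConfig with keyword arguments (the field values below are arbitrary; importing from multimolecule.modules.heads.config matches the module path documented above):

Python
from multimolecule.modules.heads.config import HeadConfig

# Illustrative values only; unspecified fields keep the defaults listed above.
head_config = HeadConfig(
    num_labels=2,
    problem_type="multilabel",
    dropout=0.1,
    output_name="last_hidden_state",
)

Such a config can then be passed as the head_config argument of a prediction head, as shown in the sections below.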

MaskedLMHeadConfig

Bases: BaseHeadConfig

Configuration class for a Masked Language Modeling head.

Parameters:

hidden_size (int | None, default: None)
    Dimensionality of the encoder layers and the pooler layer.
    Head should look for Config.hidden_size if it is None.

dropout (float, default: 0.0)
    The dropout ratio for the hidden states.

transform (str | None, default: "nonlinear")
    The transform operation applied to hidden states.

transform_act (str | None, default: "gelu")
    The activation function of transform applied to hidden states.

bias (bool, default: True)
    Whether to apply bias to the final prediction layer.

act (str | None, default: None)
    The activation function of the final prediction output.

layer_norm_eps (float, default: 1e-12)
    The epsilon used by the layer normalization layers.

output_name (str | None, default: None)
    The name of the tensor required in model outputs.
    If it is None, will use the default output name of the corresponding head.
Source code in multimolecule/modules/heads/config.py
Python
class MaskedLMHeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a Masked Language Modeling head.

    Args:
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            Head should look for [`Config.hidden_size`][multimolecule.PreTrainedConfig] if is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name:
            The name of the tensor required in model outputs.

            If is `None`, will use the default output name of the corresponding head.
    """

    hidden_size: Optional[int] = None
    dropout: float = 0.0
    transform: Optional[str] = "nonlinear"
    transform_act: Optional[str] = "gelu"
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
    output_name: Optional[str] = None

multimolecule.modules.heads.sequence

SequencePredictionHead

Bases: PredictionHead

Head for sequence-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/sequence.py
Python
@HEADS.register("sequence")
class SequencePredictionHead(PredictionHead):
    r"""
    Head for tasks in sequence-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "pooler_output"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the SequencePredictionHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[1]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'pooler_output'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the SequencePredictionHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required)
    The outputs of the model.

labels (Tensor | None, default: None)
    The labels for the head.

output_name (str | None, default: None)
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/sequence.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the SequencePredictionHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[1]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    return super().forward(output, labels, **kwargs)
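
A minimal sketch of the tuple convention described at the top of this page (shapes and values are illustrative; with a tuple input, this head reads the second element as the pooler_output):

Python
import torch
from multimolecule.models import PreTrainedConfig
from multimolecule.modules.heads import SequencePredictionHead

config = PreTrainedConfig(hidden_size=8, num_labels=2)
head = SequencePredictionHead(config)

# Tuple outputs follow the 🤗 Transformers ordering: (last_hidden_state, pooler_output, ...).
last_hidden_state = torch.randn(1, 28, config.hidden_size)
pooler_output = torch.randn(1, config.hidden_size)
output = head((last_hidden_state, pooler_output))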

multimolecule.modules.heads.token

TokenPredictionHead

Bases: PredictionHead

Head for token-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/token.py
Python
@HEADS.token.register("single", default=True)
class TokenPredictionHead(PredictionHead):
    r"""
    Head for tasks in token-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenPredictionHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        output, _, _ = self.remove_special_tokens(output, attention_mask, input_ids)

        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenPredictionHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required)
    The outputs of the model.

attention_mask (Tensor | None, default: None)
    The attention mask for the inputs.

input_ids (NestedTensor | Tensor | None, default: None)
    The input ids for the inputs.

labels (Tensor | None, default: None)
    The labels for the head.

output_name (str | None, default: None)
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self.get_attention_mask(input_ids)
    output, _, _ = self.remove_special_tokens(output, attention_mask, input_ids)

    return super().forward(output, labels, **kwargs)
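
A sketch of token-level usage with an explicit attention_mask, mirroring the ContactPredictionHead examples below (tensor shapes and configuration values are illustrative assumptions):

Python
import torch
from multimolecule.models import PreTrainedConfig
from multimolecule.modules.heads import TokenPredictionHead

config = PreTrainedConfig(hidden_size=8, num_labels=2)
head = TokenPredictionHead(config)

# One sequence of 28 positions; special token positions are trimmed inside the head.
last_hidden_state = torch.randn(1, 28, config.hidden_size)
output = head({"last_hidden_state": last_hidden_state}, attention_mask=torch.ones(1, 28))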

TokenKMerHead

Bases: PredictionHead

Head for token-level tasks with k-mer inputs.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/token.py
Python
@HEADS.token.register("kmer")
class TokenKMerHead(PredictionHead):
    r"""
    Head for tasks in token-level with kmer inputs.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.nmers = config.nmers

        # Do not pass bos_token_id and eos_token_id to unfold_kmer_embeddings
        # As they will be removed in preprocess
        self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenKMerHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        output, attention_mask, _ = self.remove_special_tokens(output, attention_mask, input_ids)

        output = self.unfold_kmer_embeddings(output, attention_mask)
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenKMerHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required)
    The outputs of the model.

attention_mask (Tensor | None, default: None)
    The attention mask for the inputs.

input_ids (NestedTensor | Tensor | None, default: None)
    The input ids for the inputs.

labels (Tensor | None, default: None)
    The labels for the head.

output_name (str | None, default: None)
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenKMerHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self.get_attention_mask(input_ids)
    output, attention_mask, _ = self.remove_special_tokens(output, attention_mask, input_ids)

    output = self.unfold_kmer_embeddings(output, attention_mask)
    return super().forward(output, labels, **kwargs)

unfold_kmer_embeddings

Python
unfold_kmer_embeddings(embeddings: Tensor, attention_mask: Tensor, nmers: int, bos_token_id: int | None = None, eos_token_id: int | None = None) -> Tensor

Unfold k-mer embeddings to token embeddings.

For k-mer input, each embedding column represents k tokens. This should be fine for sequence level tasks, but sacrifices the resolution for token level tasks. This function unfolds the k-mer embeddings to token embeddings by sliding averaging the k-mer embeddings.

For example:

input tokens = ACGU

2-mer embeddings = [<CLS>, AC, CG, GU, <SEP>].

token embeddings = [<CLS>, AC, (AC + CG) / 2, (CG + GU) / 2, GU, <SEP>].

Parameters:

embeddings (Tensor, required)
    The k-mer embeddings.

attention_mask (Tensor, required)
    The attention mask.

nmers (int, required)
    The number of tokens in each k-mer.

bos_token_id (int | None, default: None)
    The id of the beginning of sequence token.
    If not None, the first valid token will not be included in sliding averaging.

eos_token_id (int | None, default: None)
    The id of the end of sequence token.
    If not None, the last valid token will not be included in sliding averaging.

Returns:

Tensor
    The token embeddings.

Examples:

Python Console Session
>>> from danling import NestedTensor
>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]
>>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
>>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0]
>>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
>>> output[0, :, 0].tolist()
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
Source code in multimolecule/modules/heads/token.py
Python
def unfold_kmer_embeddings(
    embeddings: Tensor,
    attention_mask: Tensor,
    nmers: int,
    bos_token_id: int | None = None,
    eos_token_id: int | None = None,
) -> Tensor:
    r"""
    Unfold k-mer embeddings to token embeddings.

    For k-mer input, each embedding column represents k tokens.
    This should be fine for sequence level tasks, but sacrifices the resolution for token level tasks.
    This function unfolds the k-mer embeddings to token embeddings by sliding averaging the k-mer embeddings.

    For example:

    input tokens = `ACGU`

    2-mer embeddings = `[<CLS>, AC, CG, GU, <SEP>]`.

    token embeddings = `[<CLS>, AC, (AC + CG) / 2, (CG + GU) / 2, GU, <SEP>]`.

    Args:
        embeddings: The k-mer embeddings.
        attention_mask: The attention mask.
        nmers: The number of tokens in each k-mer.
        bos_token_id: The id of the beginning of sequence token.
            If not None, the first valid token will not be included in sliding averaging.
        eos_token_id: The id of the end of sequence token.
            If not None, the last valid token will not be included in sliding averaging.

    Returns:
        The token embeddings.

    Examples:
        >>> from danling import NestedTensor
        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]
        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
        >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0]
        >>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
        >>> output[0, :, 0].tolist()
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
    """

    batch_size, seq_length, hidden_size = embeddings.size()
    last_valid_indices = attention_mask.sum(dim=-1)
    output = torch.zeros(batch_size, seq_length + nmers - 1, hidden_size, device=embeddings.device)
    for index, (tensor, seq_length) in enumerate(zip(embeddings, last_valid_indices)):
        embedding = tensor[:seq_length]
        if bos_token_id is not None:
            embedding = embedding[1:]
        if eos_token_id is not None:
            embedding = embedding[:-1]
        if len(embedding) > nmers:
            begin = torch.stack([embedding[:i].mean(0) for i in range(1, nmers)])
            medium = embedding.unfold(0, nmers, 1).mean(-1)
            end = torch.stack([embedding[-i:].mean(0) for i in range(nmers - 1, 0, -1)])
            embedding = torch.cat([begin, medium, end])
        elif len(embedding) > 2:
            begin = torch.stack([embedding[:i].mean(0) for i in range(1, len(embedding))])
            end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)])
            embedding = torch.cat([begin, end])
        elif len(embedding) == 2:
            medium = embedding.mean(0).repeat(nmers - 1, 1)
            embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]])
        elif len(embedding) == 1:
            embedding = embedding.repeat(nmers, 1)
        else:
            raise ValueError("Sequence length is less than nmers.")
        if bos_token_id is not None:
            embedding = torch.cat([tensor[0][None, :], embedding])
        if eos_token_id is not None:
            embedding = torch.cat([embedding, tensor[seq_length - 1][None, :]])
        output[index, : seq_length + nmers - 1] = embedding
    return output

multimolecule.modules.heads.contact

ContactPredictionHead

Bases: BasePredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactPredictionHead
>>> config = PreTrainedConfig(hidden_size=8)
>>> head = ContactPredictionHead(config)
>>> input = torch.randn(1, 28, config.hidden_size)
>>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.logits.register("linear", default=True)
class ContactPredictionHead(BasePredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactPredictionHead
        >>> config = PreTrainedConfig(hidden_size=8)
        >>> head = ContactPredictionHead(config)
        >>> input = torch.randn(1, 28, config.hidden_size)
        >>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    require_attentions: bool = False
    r"""Whether the head requires attentions."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        out_channels: int = self.config.hidden_size  # type: ignore[assignment]
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(out_channels, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CRITERIONS.build(self.config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        output, _, _ = self.remove_special_tokens(output, attention_mask, input_ids)

        output = self.dropout(output)
        output = self.transform(output)
        contact_map = output.unsqueeze(1) * output.unsqueeze(2)
        contact_map = self.decoder(contact_map)
        contact_map = self.symmetrize(contact_map)
        if self.activation is not None:
            contact_map = self.activation(contact_map)

        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(contact_map, Tensor):
                    contact_map = labels.nested_like(contact_map, strict=False)
                return HeadOutput(contact_map, self.criterion(contact_map.concat, labels.concat))
            return HeadOutput(contact_map, self.criterion(contact_map, labels))
        return HeadOutput(contact_map)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = False

Whether the head requires attentions.

ContactPredictionResNetHead

Bases: ContactPredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactPredictionResNetHead
>>> config = PreTrainedConfig(hidden_size=32)
>>> head = ContactPredictionResNetHead(config)
>>> input = torch.randn(1, 28, config.hidden_size)
>>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.logits.register("resnet")
class ContactPredictionResNetHead(ContactPredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactPredictionResNetHead
        >>> config = PreTrainedConfig(hidden_size=32)
        >>> head = ContactPredictionResNetHead(config)
        >>> input = torch.randn(1, 28, config.hidden_size)
        >>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = ResNet(
            num_layers=self.config.get("num_layers", 6),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

ContactPredictionUNetHead

Bases: ContactPredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactPredictionUNetHead
>>> config = PreTrainedConfig(hidden_size=32)
>>> head = ContactPredictionUNetHead(config)
>>> input = torch.randn(1, 28, config.hidden_size)
>>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.logits.register("unet")
class ContactPredictionUNetHead(ContactPredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactPredictionUNetHead
        >>> config = PreTrainedConfig(hidden_size=32)
        >>> head = ContactPredictionUNetHead(config)
        >>> input = torch.randn(1, 28, config.hidden_size)
        >>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = UNet(
            num_layers=self.config.get("num_layers", 6),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

ContactAttentionHead

Bases: BasePredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactAttentionHead
>>> config = PreTrainedConfig(num_hidden_layers=2, num_attention_heads=4)
>>> head = ContactAttentionHead(config)
>>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
>>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.attention.register("linear")
class ContactAttentionHead(BasePredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactAttentionHead
        >>> config = PreTrainedConfig(num_hidden_layers=2, num_attention_heads=4)
        >>> head = ContactAttentionHead(config)
        >>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
        >>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
    """

    output_name: str = "attentions"
    r"""The default output to use for the head."""

    require_attentions: bool = True
    r"""Whether the head requires attentions."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        if head_config is None:
            head_config = HeadConfig(hidden_size=config.num_hidden_layers * config.num_attention_heads)
        else:
            head_config.hidden_size = config.num_hidden_layers * config.num_attention_heads
        super().__init__(config, head_config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CRITERIONS.build(self.config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[-1]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if isinstance(output, (list, tuple)):
            output = torch.stack(output, 1)
        contact_map = output.flatten(1, 2).permute(0, 2, 3, 1)

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        contact_map, _, _ = self.remove_special_tokens_2d(contact_map, attention_mask, input_ids)

        contact_map = self.dropout(contact_map)
        contact_map = self.transform(contact_map)
        contact_map = self.decoder(contact_map)
        contact_map = self.symmetrize(contact_map)
        if self.activation is not None:
            contact_map = self.activation(contact_map)

        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(contact_map, Tensor):
                    contact_map = labels.nested_like(contact_map, strict=False)
                return HeadOutput(contact_map, self.criterion(contact_map.concat, labels.concat))
            return HeadOutput(contact_map, self.criterion(contact_map, labels))
        return HeadOutput(contact_map)

output_name class-attribute instance-attribute

Python
output_name: str = 'attentions'

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = True

Whether the head requires attentions.

ContactAttentionResNetHead

Bases: ContactAttentionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactAttentionResNetHead
>>> config = PreTrainedConfig(num_hidden_layers=8, num_attention_heads=4)
>>> head = ContactAttentionResNetHead(config)
>>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
>>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.attention.register("resnet")
class ContactAttentionResNetHead(ContactAttentionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactAttentionResNetHead
        >>> config = PreTrainedConfig(num_hidden_layers=8, num_attention_heads=4)
        >>> head = ContactAttentionResNetHead(config)
        >>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
        >>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = ResNet(
            num_layers=self.config.get("num_layers", 16),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

ContactAttentionUNetHead

Bases: ContactAttentionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactAttentionUNetHead
>>> config = PreTrainedConfig(num_hidden_layers=4, num_attention_heads=8)
>>> head = ContactAttentionUNetHead(config)
>>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
>>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.attention.register("unet")
class ContactAttentionUNetHead(ContactAttentionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactAttentionUNetHead
        >>> config = PreTrainedConfig(num_hidden_layers=4, num_attention_heads=8)
        >>> head = ContactAttentionUNetHead(config)
        >>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
        >>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
    """

    output_name: str = "attentions"
    r"""The default output to use for the head."""

    require_attentions: bool = True
    r"""Whether the head requires attentions."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = UNet(
            num_layers=self.config.get("num_layers", 4),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

output_name class-attribute instance-attribute

Python
output_name: str = 'attentions'

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = True

Whether the head requires attentions.

multimolecule.modules.heads.pretrain

MaskedLMHead

Bases: BasePredictionHead

Head for masked language modeling.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (MaskedLMHeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/pretrain.py
Python
@HEADS.register("masked_lm")
class MaskedLMHead(BasePredictionHead):
    r"""
    Head for masked language modeling.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(
        self, config: PreTrainedConfig, weight: Tensor | None = None, head_config: MaskedLMHeadConfig | None = None
    ):
        if head_config is None:
            head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig()
        head_config.num_labels = config.vocab_size
        super().__init__(config, head_config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        if weight is not None:
            self.decoder.weight = weight
        if self.config.bias:
            self.bias = nn.Parameter(torch.zeros(self.num_labels))
            self.decoder.bias = self.bias
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

    def forward(
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        labels: Tensor | None = None,
        output_name: str | None = None,
    ) -> HeadOutput:
        r"""
        Forward pass of the MaskedLMHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        output = self.dropout(output)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)

        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
            return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
        return HeadOutput(output)

    def _tie_weights(self):
        self.decoder.bias = self.bias

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None) -> HeadOutput

Forward pass of the MaskedLMHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required)
    The outputs of the model.

labels (Tensor | None, default: None)
    The labels for the head.

output_name (str | None, default: None)
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/pretrain.py
Python
def forward(
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    labels: Tensor | None = None,
    output_name: str | None = None,
) -> HeadOutput:
    r"""
    Forward pass of the MaskedLMHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    output = self.dropout(output)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)

    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
        return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
    return HeadOutput(output)
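
A sketch of masked language modeling usage; it assumes that PreTrainedConfig accepts a vocab_size argument and that MaskedLMHead is exported from multimolecule.modules.heads like the other heads (all values are illustrative):

Python
import torch
from multimolecule.models import PreTrainedConfig
from multimolecule.modules.heads import MaskedLMHead

# vocab_size is an assumed keyword here; 26 and the tensor shapes are arbitrary.
config = PreTrainedConfig(hidden_size=8, vocab_size=26)
head = MaskedLMHead(config)

last_hidden_state = torch.randn(1, 28, config.hidden_size)
output = head({"last_hidden_state": last_hidden_state})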

multimolecule.modules.heads.generic

BasePredictionHead

Bases: Module

Head for tasks at all levels.

Parameters:

config (PreTrainedConfig, required)
    The configuration object for the model.

head_config (HeadConfig | None, default: None)
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/generic.py
Python
class BasePredictionHead(nn.Module):
    r"""
    Head for all-level of tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    config: HeadConfig
    r"""The configuration object for the head."""

    num_labels: int
    r"""Number of labels for the head."""

    output_name: str | None
    r"""The default output to use for the head."""

    require_attentions: bool = False
    r"""Whether the head requires attentions from the model."""

    bos_token_id: int | None = None
    r"""The ID of the beginning-of-sequence token. Usually is an alias of `cls_token_id`."""

    pad_token_id: int | None = None
    r"""The ID of the padding token."""

    eos_token_id: int | None = None
    r"""The ID of the end-of-sequence token. In rare cases, it is an alias of `sep_token_id`."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__()
        if head_config is None:
            head_config = config.head or HeadConfig(num_labels=config.num_labels)
        if not isinstance(head_config, HeadConfig):
            head_config = HeadConfig(head_config)
        if not head_config.num_labels:
            head_config.num_labels = config.num_labels
        if not head_config.hidden_size:
            head_config.hidden_size = config.hidden_size
        if not head_config.problem_type:
            head_config.problem_type = config.problem_type
        self.config = head_config
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
        self.num_labels = self.config.num_labels  # type: ignore[assignment]
        if getattr(self.config, "output_name", None) is not None:
            self.output_name = self.config.output_name

    def get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor:
        r"""
        Generate attention mask from input IDs or extract from NestedTensor.

        Creates a binary attention mask indicating which tokens should be attended to (1)
        and which should be ignored (0, typically padding tokens). For NestedTensor inputs,
        extracts the pre-computed mask. For regular tensors, compares against pad_token_id.

        Args:
            input_ids: Input token IDs as either a NestedTensor with embedded mask
                or a regular Tensor of shape `(batch_size, seq_len)`.

        Returns:
            Binary attention mask of shape `(batch_size, seq_len)` where 1 indicates
            tokens to attend to and 0 indicates tokens to ignore.

        Raises:
            ValueError: If input_ids is None or if pad_token_id is None when needed
                for regular Tensor inputs.

        Examples:
            >>> import torch
            >>> from multimolecule.models.configuration_utils import PreTrainedConfig
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=128))
            >>> input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
            >>> mask = head.get_attention_mask(input_ids)
            >>> mask
            tensor([[1, 1, 1, 0, 0],
                    [1, 1, 0, 0, 0]], dtype=torch.int32)
        """
        if isinstance(input_ids, NestedTensor):
            return input_ids.mask
        if input_ids is None:
            raise ValueError(
                f"Unable to infer attention mask for {self.__class__.__name__}, because input_ids is None."
            )
        if self.pad_token_id is None:
            raise ValueError(
                f"Unable to infer attention mask for {self.__class__.__name__}, because pad_token_id is None."
            )
        return input_ids.ne(self.pad_token_id).int()

    def remove_special_tokens(
        self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Remove special tokens and clean up model outputs using attention masks.

        Processes model outputs by removing special tokens that were added during tokenization
        and applies attention masking to zero out padding positions. This comprehensive cleanup
        is essential for sequence-level tasks where predictions should only cover the actual
        input sequence, excluding special tokens and padding.

        The method performs:
        - BOS token removal: Strips the first token from all sequences
        - EOS token removal: Strips tokens after the EOS token and the EOS token itself
        - Attention mask adjustment: Updates mask to match the trimmed sequences
        - Output cleanup: Multiplies output by attention mask to zero out padding positions

        Args:
            output: Model output tensor of shape `(batch_size, seq_len, hidden_size)`.
            attention_mask: Attention mask of shape `(batch_size, seq_len)`.
            input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.
                Used for precise EOS token location when available.

        Returns:
            Tuple containing:
                - output: Cleaned output tensor with special tokens removed and padding zeroed
                - attention_mask: Updated attention mask matching trimmed sequences
                - input_ids: Trimmed input IDs (if provided) or unchanged input

        Examples:
            >>> import torch
            >>> from multimolecule.models.configuration_utils import PreTrainedConfig
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
            >>> output = torch.randn(1, 6, 4)
            >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
            >>> input_ids = torch.tensor([[1, 10, 20, 2, 0, 0]])
            >>> new_out, new_mask, new_ids = head.remove_special_tokens(output, attention_mask, input_ids)
            >>> output.shape[1], new_out.shape[1]
            (6, 4)
            >>> new_mask
            tensor([[1, 1, 0, 0]])
        """
        if self.bos_token_id is not None:
            output = output[..., 1:, :]
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
                if isinstance(input_ids, Tensor):
                    input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
                if isinstance(eos_mask, NestedTensor):
                    eos_mask = eos_mask.tensor
                input_ids = input_ids[..., :-1]
            else:
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            output = (output * eos_mask.unsqueeze(-1))[..., :-1, :]
            attention_mask = (attention_mask * eos_mask)[..., :-1]
        if attention_mask is not None:
            output = output * attention_mask.unsqueeze(-1)
        return output, attention_mask, input_ids

    def remove_special_tokens_2d(
        self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Remove special tokens from 2D outputs like contact maps or pairwise interaction matrices.

        Extends `remove_special_tokens` to handle 2D outputs where both sequence dimensions
        need special token removal. This is crucial for contact prediction and structure
        analysis tasks where the output represents pairwise relationships between residues.

        The method removes:
        - BOS tokens: Strips first row and column from the 2D output
        - EOS tokens: Strips rows/columns after EOS positions and the EOS positions themselves
        - Updates attention mask: Creates 2D mask from 1D sequence mask

        Args:
            output: 2D model output of shape `(batch_size, seq_len, seq_len, channels)`.
            attention_mask: 1D attention mask of shape `(batch_size, seq_len)`.
            input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.

        Returns:
            Tuple containing:
                - output: Trimmed 2D output with special tokens removed from both dimensions
                - attention_mask: 2D attention mask of shape `(batch_size, new_len, new_len)`
                - input_ids: Trimmed input IDs (if provided) or unchanged input

        Examples:
            >>> import torch
            >>> from multimolecule.models.configuration_utils import PreTrainedConfig
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
            >>> output = torch.randn(1, 5, 5, 1)
            >>> input_ids = torch.tensor([[1, 10, 20, 2, 0]])
            >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
            >>> new_out, new_mask, new_ids = head.remove_special_tokens_2d(output, attention_mask, input_ids)
            >>> output.shape, new_out.shape
            (torch.Size([1, 5, 5, 1]), torch.Size([1, 3, 3, 1]))
            >>> new_mask.shape
            torch.Size([1, 3, 3])
        """
        if self.bos_token_id is not None:
            output = output[..., 1:, 1:, :]
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
                if isinstance(input_ids, Tensor):
                    input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
                if isinstance(eos_mask, NestedTensor):
                    eos_mask = eos_mask.tensor
                input_ids = input_ids[..., :-1]
            else:
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            attention_mask = (attention_mask * eos_mask)[..., :-1]
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            output = (output * eos_mask.unsqueeze(-1))[..., :-1, :-1, :]
        attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
        if attention_mask is not None:
            output = output * attention_mask.unsqueeze(-1)
        return output, attention_mask, input_ids

    @staticmethod
    def symmetrize(x: Tensor) -> Tensor:
        r"""
        Make output symmetric by averaging the tensor with its transpose.

        Args:
            x: Input tensor of shape (batch_size, seq_len, seq_len, channels).

        Returns:
            Symmetric tensor with the same shape as input.

        Examples:
            >>> import torch
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> x = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
            >>> x.squeeze(-1)
            tensor([[[1., 2.],
                     [3., 4.]]])
            >>> symmetric = BasePredictionHead.symmetrize(x)
            >>> symmetric.squeeze(-1)
            tensor([[[1.0000, 2.5000],
                     [2.5000, 4.0000]]])
            >>> torch.allclose(symmetric, symmetric.transpose(1, 2))
            True
        """
        return (x + x.transpose(1, 2)) / 2

    @staticmethod
    def average_product_correct(x: Tensor) -> Tensor:
        r"""Perform Average Product Correction (APC) to remove systematic biases from contact maps.

        APC removes row and column biases that arise from varying residue frequencies and
        structural preferences in molecular contact maps. It subtracts the expected contact
        probability based on marginal frequencies to reveal genuine structural interactions.

        The correction formula: `corrected = original - (row_sums × col_sums) / total_sum`

        This is essential for accurate contact prediction across DNA, RNA, and protein structures.

        Args:
            x: Contact map tensor of shape `(batch_size, seq_len, seq_len, channels)`

        Returns:
            Bias-corrected contact map with the same shape as input

        Note:
            This correction removes spurious correlations caused by sequence composition bias,
            making genuine molecular contacts stand out more clearly from background noise.

        Examples:
            >>> import torch
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> x = torch.tensor([[[[0.8, 0.6], [0.7, 0.5]], [[0.6, 0.4], [0.5, 0.3]]]])
            >>> x.squeeze(-1)
            tensor([[[[0.8000, 0.6000],
                      [0.7000, 0.5000]],
            <BLANKLINE>
                     [[0.6000, 0.4000],
                      [0.5000, 0.3000]]]])
            >>> corrected = BasePredictionHead.average_product_correct(x)
            >>> corrected.squeeze(-1)
            tensor([[[[-0.0077, -0.0111],
                      [ 0.0077,  0.0111]],
            <BLANKLINE>
                     [[ 0.0077,  0.0111],
                      [-0.0077, -0.0111]]]])
            >>> row_sums = corrected.sum(dim=2).squeeze()
            >>> col_sums = corrected.sum(dim=1).squeeze()
            >>> torch.allclose(row_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
            True
            >>> torch.allclose(col_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
            True
        """
        return x - x.sum(1, keepdims=True) * x.sum(2, keepdims=True) / x.sum((1, 2), keepdims=True)

output_name instance-attribute

Python
output_name: str | None

The default name of the model output used by the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = False

Whether the head requires attentions from the model.

config instance-attribute

Python
config: HeadConfig = head_config

The configuration object for the head.

bos_token_id class-attribute instance-attribute

Python
bos_token_id: int | None = bos_token_id

The ID of the beginning-of-sequence token. It is usually an alias of cls_token_id.

eos_token_id class-attribute instance-attribute

Python
eos_token_id: int | None = eos_token_id

The ID of the end-of-sequence token. In rare cases, it is an alias of sep_token_id.

pad_token_id class-attribute instance-attribute

Python
pad_token_id: int | None = pad_token_id

The ID of the padding token.

num_labels instance-attribute

Python
num_labels: int = num_labels

Number of labels for the head.
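
These attributes mirror fields of the model configuration. A minimal inspection sketch, assuming PreTrainedConfig supplies default special-token ids as in the examples below:

Python
from multimolecule.models.configuration_utils import PreTrainedConfig
from multimolecule.modules.heads.generic import BasePredictionHead

# Illustrative only: the printed ids depend on the configuration defaults.
head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=8))
print(head.num_labels)  # 2, taken from the head configuration
print(head.pad_token_id, head.bos_token_id, head.eos_token_id)  # special-token ids from the model config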

get_attention_mask

Python
get_attention_mask(input_ids: NestedTensor | Tensor) -> Tensor

Generate attention mask from input IDs or extract from NestedTensor.

Creates a binary attention mask indicating which tokens should be attended to (1) and which should be ignored (0, typically padding tokens). For NestedTensor inputs, extracts the pre-computed mask. For regular tensors, compares against pad_token_id.

Parameters:

Name Type Description Default
input_ids
NestedTensor | Tensor

Input token IDs as either a NestedTensor with embedded mask or a regular Tensor of shape (batch_size, seq_len).

required

Returns:

Type Description
Tensor

Binary attention mask of shape (batch_size, seq_len) where 1 indicates tokens to attend to and 0 indicates tokens to ignore.

Raises:

Type Description
ValueError

If input_ids is None or if pad_token_id is None when needed for regular Tensor inputs.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.configuration_utils import PreTrainedConfig
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=128))
>>> input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
>>> mask = head.get_attention_mask(input_ids)
>>> mask
tensor([[1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0]], dtype=torch.int32)
Source code in multimolecule/modules/heads/generic.py
Python
def get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor:
    r"""
    Generate attention mask from input IDs or extract from NestedTensor.

    Creates a binary attention mask indicating which tokens should be attended to (1)
    and which should be ignored (0, typically padding tokens). For NestedTensor inputs,
    extracts the pre-computed mask. For regular tensors, compares against pad_token_id.

    Args:
        input_ids: Input token IDs as either a NestedTensor with embedded mask
            or a regular Tensor of shape `(batch_size, seq_len)`.

    Returns:
        Binary attention mask of shape `(batch_size, seq_len)` where 1 indicates
        tokens to attend to and 0 indicates tokens to ignore.

    Raises:
        ValueError: If input_ids is None or if pad_token_id is None when needed
            for regular Tensor inputs.

    Examples:
        >>> import torch
        >>> from multimolecule.models.configuration_utils import PreTrainedConfig
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=128))
        >>> input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
        >>> mask = head.get_attention_mask(input_ids)
        >>> mask
        tensor([[1, 1, 1, 0, 0],
                [1, 1, 0, 0, 0]], dtype=torch.int32)
    """
    if isinstance(input_ids, NestedTensor):
        return input_ids.mask
    if input_ids is None:
        raise ValueError(
            f"Unable to infer attention mask for {self.__class__.__name__}, because input_ids is None."
        )
    if self.pad_token_id is None:
        raise ValueError(
            f"Unable to infer attention mask for {self.__class__.__name__}, because pad_token_id is None."
        )
    return input_ids.ne(self.pad_token_id).int()
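
If the configuration defines no padding token, the mask cannot be inferred from a plain Tensor and a ValueError is raised. A minimal sketch of this failure mode, assuming PreTrainedConfig accepts pad_token_id as a keyword the way 🤗 Transformers configurations do:

Python
import torch
from multimolecule.models.configuration_utils import PreTrainedConfig
from multimolecule.modules.heads.generic import BasePredictionHead

# Hypothetical configuration without a padding token.
head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=8, pad_token_id=None))
try:
    head.get_attention_mask(torch.tensor([[1, 2, 3]]))
except ValueError as error:
    print(error)  # explains that pad_token_id is None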

remove_special_tokens

Python
remove_special_tokens(output: Tensor, attention_mask: Tensor, input_ids: Tensor | None = None) -> Tuple[Tensor, Tensor, Tensor]

Remove special tokens and clean up model outputs using attention masks.

Processes model outputs by removing special tokens that were added during tokenization and applies attention masking to zero out padding positions. This comprehensive cleanup is essential for sequence-level tasks where predictions should only cover the actual input sequence, excluding special tokens and padding.

The method performs:

- BOS token removal: Strips the first token from all sequences
- EOS token removal: Strips tokens after the EOS token and the EOS token itself
- Attention mask adjustment: Updates mask to match the trimmed sequences
- Output cleanup: Multiplies output by attention mask to zero out padding positions

Parameters:

Name Type Description Default
output
Tensor

Model output tensor of shape (batch_size, seq_len, hidden_size).

required
attention_mask
Tensor

Attention mask of shape (batch_size, seq_len).

required
input_ids
Tensor | None

Optional input token IDs of shape (batch_size, seq_len). Used for precise EOS token location when available.

None

Returns:

Type Description
Tuple[Tensor, Tensor, Tensor]

Tuple containing:

- output: Cleaned output tensor with special tokens removed and padding zeroed
- attention_mask: Updated attention mask matching trimmed sequences
- input_ids: Trimmed input IDs (if provided) or unchanged input

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.configuration_utils import PreTrainedConfig
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
>>> output = torch.randn(1, 6, 4)
>>> attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
>>> input_ids = torch.tensor([[1, 10, 20, 2, 0, 0]])
>>> new_out, new_mask, new_ids = head.remove_special_tokens(output, attention_mask, input_ids)
>>> output.shape[1], new_out.shape[1]
(6, 4)
>>> new_mask
tensor([[1, 1, 0, 0]])
Source code in multimolecule/modules/heads/generic.py
Python
def remove_special_tokens(
    self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None = None
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Remove special tokens and clean up model outputs using attention masks.

    Processes model outputs by removing special tokens that were added during tokenization
    and applies attention masking to zero out padding positions. This comprehensive cleanup
    is essential for sequence-level tasks where predictions should only cover the actual
    input sequence, excluding special tokens and padding.

    The method performs:
    - BOS token removal: Strips the first token from all sequences
    - EOS token removal: Strips tokens after the EOS token and the EOS token itself
    - Attention mask adjustment: Updates mask to match the trimmed sequences
    - Output cleanup: Multiplies output by attention mask to zero out padding positions

    Args:
        output: Model output tensor of shape `(batch_size, seq_len, hidden_size)`.
        attention_mask: Attention mask of shape `(batch_size, seq_len)`.
        input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.
            Used for precise EOS token location when available.

    Returns:
        Tuple containing:
            - output: Cleaned output tensor with special tokens removed and padding zeroed
            - attention_mask: Updated attention mask matching trimmed sequences
            - input_ids: Trimmed input IDs (if provided) or unchanged input

    Examples:
        >>> import torch
        >>> from multimolecule.models.configuration_utils import PreTrainedConfig
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
        >>> output = torch.randn(1, 6, 4)
        >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
        >>> input_ids = torch.tensor([[1, 10, 20, 2, 0, 0]])
        >>> new_out, new_mask, new_ids = head.remove_special_tokens(output, attention_mask, input_ids)
        >>> output.shape[1], new_out.shape[1]
        (6, 4)
        >>> new_mask
        tensor([[1, 1, 0, 0]])
    """
    if self.bos_token_id is not None:
        output = output[..., 1:, :]
        attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
            if isinstance(input_ids, Tensor):
                input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
            if isinstance(eos_mask, NestedTensor):
                eos_mask = eos_mask.tensor
            input_ids = input_ids[..., :-1]
        else:
            last_valid_indices = attention_mask.sum(dim=-1) - 1
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
        output = (output * eos_mask.unsqueeze(-1))[..., :-1, :]
        attention_mask = (attention_mask * eos_mask)[..., :-1]
    if attention_mask is not None:
        output = output * attention_mask.unsqueeze(-1)
    return output, attention_mask, input_ids
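
When input_ids is not available, the EOS position is inferred as the last position where the attention mask is 1 (the else branch above). A minimal sketch of this fallback, assuming the same default bos/eos token ids used in the doctests:

Python
import torch
from multimolecule.models.configuration_utils import PreTrainedConfig
from multimolecule.modules.heads.generic import BasePredictionHead

head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
output = torch.randn(1, 6, 4)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # BOS, two tokens, EOS, two pads
new_out, new_mask, _ = head.remove_special_tokens(output, attention_mask)
print(new_out.shape)  # expected: torch.Size([1, 4, 4])
print(new_mask)       # expected: tensor([[1, 1, 0, 0]])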

remove_special_tokens_2d

Python
remove_special_tokens_2d(output: Tensor, attention_mask: Tensor, input_ids: Tensor | None = None) -> Tuple[Tensor, Tensor, Tensor]

Remove special tokens from 2D outputs like contact maps or pairwise interaction matrices.

Extends remove_special_tokens to handle 2D outputs where both sequence dimensions need special token removal. This is crucial for contact prediction and structure analysis tasks where the output represents pairwise relationships between residues.

The method removes:

- BOS tokens: Strips first row and column from the 2D output
- EOS tokens: Strips rows/columns after EOS positions and the EOS positions themselves
- Updates attention mask: Creates 2D mask from 1D sequence mask

Parameters:

Name Type Description Default
output
Tensor

2D model output of shape (batch_size, seq_len, seq_len, channels).

required
attention_mask
Tensor

1D attention mask of shape (batch_size, seq_len).

required
input_ids
Tensor | None

Optional input token IDs of shape (batch_size, seq_len).

None

Returns:

Type Description
Tuple[Tensor, Tensor, Tensor]

Tuple containing:

- output: Trimmed 2D output with special tokens removed from both dimensions
- attention_mask: 2D attention mask of shape (batch_size, new_len, new_len)
- input_ids: Trimmed input IDs (if provided) or unchanged input

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.configuration_utils import PreTrainedConfig
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
>>> output = torch.randn(1, 5, 5, 1)
>>> input_ids = torch.tensor([[1, 10, 20, 2, 0]])
>>> attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
>>> new_out, new_mask, new_ids = head.remove_special_tokens_2d(output, attention_mask, input_ids)
>>> output.shape, new_out.shape
(torch.Size([1, 5, 5, 1]), torch.Size([1, 3, 3, 1]))
>>> new_mask.shape
torch.Size([1, 3, 3])
Source code in multimolecule/modules/heads/generic.py
Python
def remove_special_tokens_2d(
    self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None = None
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Remove special tokens from 2D outputs like contact maps or pairwise interaction matrices.

    Extends `remove_special_tokens` to handle 2D outputs where both sequence dimensions
    need special token removal. This is crucial for contact prediction and structure
    analysis tasks where the output represents pairwise relationships between residues.

    The method removes:
    - BOS tokens: Strips first row and column from the 2D output
    - EOS tokens: Strips rows/columns after EOS positions and the EOS positions themselves
    - Updates attention mask: Creates 2D mask from 1D sequence mask

    Args:
        output: 2D model output of shape `(batch_size, seq_len, seq_len, channels)`.
        attention_mask: 1D attention mask of shape `(batch_size, seq_len)`.
        input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.

    Returns:
        Tuple containing:
            - output: Trimmed 2D output with special tokens removed from both dimensions
            - attention_mask: 2D attention mask of shape `(batch_size, new_len, new_len)`
            - input_ids: Trimmed input IDs (if provided) or unchanged input

    Examples:
        >>> import torch
        >>> from multimolecule.models.configuration_utils import PreTrainedConfig
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
        >>> output = torch.randn(1, 5, 5, 1)
        >>> input_ids = torch.tensor([[1, 10, 20, 2, 0]])
        >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
        >>> new_out, new_mask, new_ids = head.remove_special_tokens_2d(output, attention_mask, input_ids)
        >>> output.shape, new_out.shape
        (torch.Size([1, 5, 5, 1]), torch.Size([1, 3, 3, 1]))
        >>> new_mask.shape
        torch.Size([1, 3, 3])
    """
    if self.bos_token_id is not None:
        output = output[..., 1:, 1:, :]
        attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
            if isinstance(input_ids, Tensor):
                input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
            if isinstance(eos_mask, NestedTensor):
                eos_mask = eos_mask.tensor
            input_ids = input_ids[..., :-1]
        else:
            last_valid_indices = attention_mask.sum(dim=-1) - 1
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
        attention_mask = (attention_mask * eos_mask)[..., :-1]
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        output = (output * eos_mask.unsqueeze(-1))[..., :-1, :-1, :]
    attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    if attention_mask is not None:
        output = output * attention_mask.unsqueeze(-1)
    return output, attention_mask, input_ids
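
The 2D attention mask returned here is the outer product of the trimmed 1D mask with itself, so a pair (i, j) is kept only when both positions are valid. A standalone sketch of that step:

Python
import torch

mask = torch.tensor([[1, 1, 1, 0]])                # trimmed 1D mask, shape (batch_size, seq_len)
pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)  # shape (batch_size, seq_len, seq_len)
print(pair_mask[0])
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [0, 0, 0, 0]])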

symmetrize staticmethod

Python
symmetrize(x: Tensor) -> Tensor

Make output symmetric by averaging the tensor with its transpose.

Parameters:

Name Type Description Default
x
Tensor

Input tensor of shape (batch_size, seq_len, seq_len, channels).

required

Returns:

Type Description
Tensor

Symmetric tensor with the same shape as input.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> x = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
>>> x.squeeze(-1)
tensor([[[1., 2.],
         [3., 4.]]])
>>> symmetric = BasePredictionHead.symmetrize(x)
>>> symmetric.squeeze(-1)
tensor([[[1.0000, 2.5000],
         [2.5000, 4.0000]]])
>>> torch.allclose(symmetric, symmetric.transpose(1, 2))
True
Source code in multimolecule/modules/heads/generic.py
Python
@staticmethod
def symmetrize(x: Tensor) -> Tensor:
    r"""
    Make output symmetric by averaging the tensor with its transpose.

    Args:
        x: Input tensor of shape (batch_size, seq_len, seq_len, channels).

    Returns:
        Symmetric tensor with the same shape as input.

    Examples:
        >>> import torch
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> x = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
        >>> x.squeeze(-1)
        tensor([[[1., 2.],
                 [3., 4.]]])
        >>> symmetric = BasePredictionHead.symmetrize(x)
        >>> symmetric.squeeze(-1)
        tensor([[[1.0000, 2.5000],
                 [2.5000, 4.0000]]])
        >>> torch.allclose(symmetric, symmetric.transpose(1, 2))
        True
    """
    return (x + x.transpose(1, 2)) / 2
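
symmetrize and average_product_correct are independent utilities, but for pairwise outputs such as contact maps they can be applied in sequence; the result stays symmetric because APC subtracts a term built from matching row and column sums. A sketch of such a chain (not a pipeline prescribed by the library):

Python
import torch
from multimolecule.modules.heads.generic import BasePredictionHead

pairwise = torch.rand(1, 6, 6, 1)  # (batch_size, seq_len, seq_len, channels)
pairwise = BasePredictionHead.symmetrize(pairwise)
pairwise = BasePredictionHead.average_product_correct(pairwise)
print(torch.allclose(pairwise, pairwise.transpose(1, 2)))  # True: symmetry is preserved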

average_product_correct staticmethod

Python
average_product_correct(x: Tensor) -> Tensor

Perform Average Product Correction (APC) to remove systematic biases from contact maps.

APC removes row and column biases that arise from varying residue frequencies and structural preferences in molecular contact maps. It subtracts the expected contact probability based on marginal frequencies to reveal genuine structural interactions.

The correction formula: corrected = original - (row_sums × col_sums) / total_sum

This is essential for accurate contact prediction across DNA, RNA, and protein structures.

Parameters:

Name Type Description Default
x
Tensor

Contact map tensor of shape (batch_size, seq_len, seq_len, channels)

required

Returns:

Type Description
Tensor

Bias-corrected contact map with the same shape as input

Note

This correction removes spurious correlations caused by sequence composition bias, making genuine molecular contacts stand out more clearly from background noise.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> x = torch.tensor([[[[0.8, 0.6], [0.7, 0.5]], [[0.6, 0.4], [0.5, 0.3]]]])
>>> x.squeeze(-1)
tensor([[[[0.8000, 0.6000],
          [0.7000, 0.5000]],

         [[0.6000, 0.4000],
          [0.5000, 0.3000]]]])
>>> corrected = BasePredictionHead.average_product_correct(x)
>>> corrected.squeeze(-1)
tensor([[[[-0.0077, -0.0111],
          [ 0.0077,  0.0111]],

         [[ 0.0077,  0.0111],
          [-0.0077, -0.0111]]]])
>>> row_sums = corrected.sum(dim=2).squeeze()
>>> col_sums = corrected.sum(dim=1).squeeze()
>>> torch.allclose(row_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
True
>>> torch.allclose(col_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
True
Source code in multimolecule/modules/heads/generic.py
Python
@staticmethod
def average_product_correct(x: Tensor) -> Tensor:
    r"""Perform Average Product Correction (APC) to remove systematic biases from contact maps.

    APC removes row and column biases that arise from varying residue frequencies and
    structural preferences in molecular contact maps. It subtracts the expected contact
    probability based on marginal frequencies to reveal genuine structural interactions.

    The correction formula: `corrected = original - (row_sums × col_sums) / total_sum`

    This is essential for accurate contact prediction across DNA, RNA, and protein structures.

    Args:
        x: Contact map tensor of shape `(batch_size, seq_len, seq_len, channels)`

    Returns:
        Bias-corrected contact map with the same shape as input

    Note:
        This correction removes spurious correlations caused by sequence composition bias,
        making genuine molecular contacts stand out more clearly from background noise.

    Examples:
        >>> import torch
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> x = torch.tensor([[[[0.8, 0.6], [0.7, 0.5]], [[0.6, 0.4], [0.5, 0.3]]]])
        >>> x.squeeze(-1)
        tensor([[[[0.8000, 0.6000],
                  [0.7000, 0.5000]],
        <BLANKLINE>
                 [[0.6000, 0.4000],
                  [0.5000, 0.3000]]]])
        >>> corrected = BasePredictionHead.average_product_correct(x)
        >>> corrected.squeeze(-1)
        tensor([[[[-0.0077, -0.0111],
                  [ 0.0077,  0.0111]],
        <BLANKLINE>
                 [[ 0.0077,  0.0111],
                  [-0.0077, -0.0111]]]])
        >>> row_sums = corrected.sum(dim=2).squeeze()
        >>> col_sums = corrected.sum(dim=1).squeeze()
        >>> torch.allclose(row_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
        True
        >>> torch.allclose(col_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
        True
    """
    return x - x.sum(1, keepdims=True) * x.sum(2, keepdims=True) / x.sum((1, 2), keepdims=True)
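
The one-line implementation above matches the formula stated earlier. A sketch that computes the correction by hand and compares it against the method:

Python
import torch
from multimolecule.modules.heads.generic import BasePredictionHead

x = torch.rand(1, 4, 4, 1)               # contact map (batch_size, seq_len, seq_len, channels)
row_sums = x.sum(dim=2, keepdim=True)    # (batch_size, seq_len, 1, channels)
col_sums = x.sum(dim=1, keepdim=True)    # (batch_size, 1, seq_len, channels)
total = x.sum(dim=(1, 2), keepdim=True)  # (batch_size, 1, 1, channels)
manual = x - row_sums * col_sums / total
print(torch.allclose(manual, BasePredictionHead.average_product_correct(x)))  # True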

PredictionHead

Bases: BasePredictionHead

Head for all levels of tasks.

Parameters:

Name Type Description Default

config

PreTrainedConfig

The configuration object for the model.

required

head_config

HeadConfig | None

The configuration object for the head. If None, will use configuration from the config.

None
Source code in multimolecule/modules/heads/generic.py
Python
class PredictionHead(BasePredictionHead):
    r"""
    Head for all levels of tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CRITERIONS.build(self.config)

    def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
        r"""
        Forward pass of the PredictionHead.

        Args:
            embeddings: The embeddings to be passed through the head.
            labels: The labels for the head.
        """
        if kwargs:
            warn(
                f"The following arguments are not applicable to {self.__class__.__name__} "
                f"and will be ignored: {kwargs.keys()}"
            )
        output = self.dropout(embeddings)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, self.criterion(output.concat, labels.concat))
            return HeadOutput(output, self.criterion(output, labels))
        return HeadOutput(output)

forward

Python
forward(embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput

Forward pass of the PredictionHead.

Parameters:

Name Type Description Default
embeddings
Tensor

The embeddings to be passed through the head.

required
labels
Tensor | None

The labels for the head.

required
Source code in multimolecule/modules/heads/generic.py
Python
def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
    r"""
    Forward pass of the PredictionHead.

    Args:
        embeddings: The embeddings to be passed through the head.
        labels: The labels for the head.
    """
    if kwargs:
        warn(
            f"The following arguments are not applicable to {self.__class__.__name__} "
            f"and will be ignored: {kwargs.keys()}"
        )
    output = self.dropout(embeddings)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)
    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, self.criterion(output.concat, labels.concat))
        return HeadOutput(output, self.criterion(output, labels))
    return HeadOutput(output)
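
An end-to-end usage sketch, assuming PreTrainedConfig supplies usable defaults for the head fields that are not given explicitly (dropout, transform, act, problem_type):

Python
import torch
from multimolecule.models.configuration_utils import PreTrainedConfig
from multimolecule.modules.heads.generic import PredictionHead

# Hypothetical sketch; the exact head behaviour follows the configuration defaults.
config = PreTrainedConfig(num_labels=2, hidden_size=8)
head = PredictionHead(config)

embeddings = torch.randn(1, 5, 8)        # (batch_size, seq_len, hidden_size)
out = head(embeddings, labels=None)      # pass labels to also compute a loss
print(out.logits.shape)                  # expected: torch.Size([1, 5, 2])
print(out.loss)                          # None when no labels are given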

multimolecule.modules.heads.output

HeadOutput dataclass

Bases: ModelOutput

Output of a prediction head.

Parameters:

Name Type Description Default

logits

FloatTensor

The prediction logits from the head.

required

loss

FloatTensor | None

The loss from the head. Defaults to None.

None
Source code in multimolecule/modules/heads/output.py
Python
@dataclass
class HeadOutput(ModelOutput):
    r"""
    Output of a prediction head.

    Args:
        logits: The prediction logits from the head.
        loss: The loss from the head.
            Defaults to None.
    """

    logits: FloatTensor
    loss: FloatTensor | None = None
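
Since HeadOutput is a ModelOutput, it can be used both as a dataclass and as a mapping, following the 🤗 Transformers convention. A minimal sketch:

Python
import torch
from multimolecule.modules.heads.output import HeadOutput

out = HeadOutput(logits=torch.randn(2, 3))
print(out.loss)             # None: loss defaults to None
print(out["logits"].shape)  # dict-style access, torch.Size([2, 3])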