heads

heads provide a collection of pre-defined prediction heads.

heads take either a ModelOutput, a dict, or a tuple as input. They automatically look for the model output required for prediction and process it accordingly.

Some prediction heads, such as ContactPredictionHead, may require additional information like the attention_mask or the input_ids. These additional arguments can be passed as positional or keyword arguments.

Note that heads use the same ModelOutput conventions as 🤗 Transformers. If the model output is a tuple, the first element is treated as the last_hidden_state, the second element as the pooler_output, and the last element as the attentions. It is the user’s responsibility to ensure that the model output is correctly formatted.

If the model output is a ModelOutput or a dict, heads look up HeadConfig.output_name in the model output. You can specify output_name in the HeadConfig to ensure that the heads can correctly locate the required tensor.
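
For a quick illustration, here is a minimal sketch of passing a dict and a tuple to a head. It assumes SequencePredictionHead is importable from multimolecule.modules.heads, as in the examples further down this page, and that PreTrainedConfig accepts hidden_size and num_labels as in those examples.

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import SequencePredictionHead
>>> config = PreTrainedConfig(hidden_size=8, num_labels=2)
>>> head = SequencePredictionHead(config)
>>> pooler_output = torch.randn(1, config.hidden_size)
>>> output = head({"pooler_output": pooler_output})  # dict: looked up by output_name
>>> output = head((torch.randn(1, 5, config.hidden_size), pooler_output))  # tuple: element 1 is the pooler_output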

multimolecule.modules.heads.config

HeadConfig

Bases: BaseHeadConfig

Configuration class for a prediction head.

Parameters:

num_labels (required):
    Number of labels to use in the last layer added to the model, typically for a classification task.
    Head should look for Config.num_labels if it is None.

problem_type (required):
    Problem type for XxxForYyyPrediction models. Can be one of "binary", "regression", "multiclass" or "multilabel".
    Head should look for Config.problem_type if it is None.

hidden_size (required):
    Dimensionality of the encoder layers and the pooler layer.
    Head should look for Config.hidden_size if it is None.

dropout (required):
    The dropout ratio for the hidden states.

transform (required):
    The transform operation applied to hidden states.

transform_act (required):
    The activation function of transform applied to hidden states.

bias (required):
    Whether to apply bias to the final prediction layer.

act (required):
    The activation function of the final prediction output.

layer_norm_eps (required):
    The epsilon used by the layer normalization layers.

output_name (required):
    The name of the tensor required in model outputs.
    If it is None, will use the default output name of the corresponding head.

type (required):
    The type of the head in the model.
    This is used by MultiMoleculeModel to construct heads.
Source code in multimolecule/modules/heads/config.py
Python
class HeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a prediction head.

    Args:
        num_labels:
            Number of labels to use in the last layer added to the model, typically for a classification task.

            Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`.
        problem_type:
            Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`,
            `"multiclass"` or `"multilabel"`.

            Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`.
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            Head should look for [`Config.hidden_size`][multimolecule.PreTrainedConfig] if is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name:
            The name of the tensor required in model outputs.

            If is `None`, will use the default output name of the corresponding head.
        type:
            The type of the head in the model.

            This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads.
    """

    num_labels: Optional[int] = None
    problem_type: Optional[str] = None
    hidden_size: Optional[int] = None
    dropout: float = 0.0
    transform: Optional[str] = None
    transform_act: Optional[str] = "gelu"
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
    output_name: Optional[str] = None
    type: Optional[str] = None
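
A head configuration can also be constructed directly. The following is a minimal sketch with illustrative values (examples, not library defaults), assuming HeadConfig is importable from multimolecule.modules.heads.config as the module path above suggests; unset fields fall back to the model Config as described above.

Python
from multimolecule.modules.heads.config import HeadConfig

# Illustrative values only; omitted fields fall back to the model Config
head_config = HeadConfig(
    num_labels=2,
    problem_type="binary",
    hidden_size=768,
    dropout=0.1,
    transform="nonlinear",
    output_name="pooler_output",
)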

MaskedLMHeadConfig

Bases: BaseHeadConfig

Configuration class for a Masked Language Modeling head.

Parameters:

hidden_size (required):
    Dimensionality of the encoder layers and the pooler layer.
    Head should look for Config.hidden_size if it is None.

dropout (required):
    The dropout ratio for the hidden states.

transform (required):
    The transform operation applied to hidden states.

transform_act (required):
    The activation function of transform applied to hidden states.

bias (required):
    Whether to apply bias to the final prediction layer.

act (required):
    The activation function of the final prediction output.

layer_norm_eps (required):
    The epsilon used by the layer normalization layers.

output_name (required):
    The name of the tensor required in model outputs.
    If it is None, will use the default output name of the corresponding head.
Source code in multimolecule/modules/heads/config.py
Python
class MaskedLMHeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a Masked Language Modeling head.

    Args:
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            Head should look for [`Config.hidden_size`][multimolecule.PreTrainedConfig] if is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name:
            The name of the tensor required in model outputs.

            If is `None`, will use the default output name of the corresponding head.
    """

    hidden_size: Optional[int] = None
    dropout: float = 0.0
    transform: Optional[str] = "nonlinear"
    transform_act: Optional[str] = "gelu"
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
    output_name: Optional[str] = None
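
As a sketch of how this configuration is typically consumed: MaskedLMHead (documented below) reads config.lm_head, falling back to config.head, when no head_config is passed explicitly. Attaching a custom configuration to a model config might therefore look like the following; the attribute assignment and values are illustrative.

Python
from multimolecule.models import PreTrainedConfig
from multimolecule.modules.heads.config import MaskedLMHeadConfig

config = PreTrainedConfig(hidden_size=8)
# Illustrative: MaskedLMHead falls back to `config.head` if `config.lm_head` is absent
config.lm_head = MaskedLMHeadConfig(dropout=0.1, transform="nonlinear")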

multimolecule.modules.heads.sequence

SequencePredictionHead

Bases: PredictionHead

Head for sequence-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/sequence.py
Python
@HEADS.register("sequence")
class SequencePredictionHead(PredictionHead):
    r"""
    Head for tasks in sequence-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "pooler_output"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the SequencePredictionHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[1]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'pooler_output'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the SequencePredictionHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required):
    The outputs of the model.

labels (Tensor | None, default: None):
    The labels for the head.

output_name (str | None, default: None):
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/sequence.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the SequencePredictionHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[1]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    return super().forward(output, labels, **kwargs)
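
The output_name argument lets the head read from a non-default key, which is useful when the backbone returns custom output names. A minimal sketch, under the same import assumptions as the examples above; the key name is hypothetical:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import SequencePredictionHead
>>> head = SequencePredictionHead(PreTrainedConfig(hidden_size=8, num_labels=2))
>>> outputs = {"my_pooled_output": torch.randn(1, 8)}  # hypothetical key name
>>> output = head(outputs, output_name="my_pooled_output")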

multimolecule.modules.heads.token

TokenPredictionHead

Bases: PredictionHead

Head for token-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/token.py
Python
@HEADS.token.register("single", default=True)
class TokenPredictionHead(PredictionHead):
    r"""
    Head for tasks in token-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenPredictionHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        output, _, _ = self.remove_special_tokens(output, attention_mask, input_ids)

        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenPredictionHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required):
    The outputs of the model.

attention_mask (Tensor | None, default: None):
    The attention mask for the inputs.

input_ids (NestedTensor | Tensor | None, default: None):
    The input ids for the inputs.

labels (Tensor | None, default: None):
    The labels for the head.

output_name (str | None, default: None):
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self.get_attention_mask(input_ids)
    output, _, _ = self.remove_special_tokens(output, attention_mask, input_ids)

    return super().forward(output, labels, **kwargs)
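
A minimal usage sketch, mirroring the contact head examples below and assuming TokenPredictionHead is importable from multimolecule.modules.heads; the head strips special-token positions before projecting each remaining token to num_labels logits.

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import TokenPredictionHead
>>> config = PreTrainedConfig(hidden_size=8, num_labels=2)
>>> head = TokenPredictionHead(config)
>>> hidden = torch.randn(1, 7, config.hidden_size)
>>> output = head({"last_hidden_state": hidden}, attention_mask=torch.ones(1, 7))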

TokenKMerHead

Bases: PredictionHead

Head for token-level tasks with k-mer inputs.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/token.py
Python
@HEADS.token.register("kmer")
class TokenKMerHead(PredictionHead):
    r"""
    Head for tasks in token-level with kmer inputs.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.nmers = config.nmers

        # Do not pass bos_token_id and eos_token_id to unfold_kmer_embeddings
        # As they will be removed in preprocess
        self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenKMerHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        output, attention_mask, _ = self.remove_special_tokens(output, attention_mask, input_ids)

        output = self.unfold_kmer_embeddings(output, attention_mask)
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenKMerHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required):
    The outputs of the model.

attention_mask (Tensor | None, default: None):
    The attention mask for the inputs.

input_ids (NestedTensor | Tensor | None, default: None):
    The input ids for the inputs.

labels (Tensor | None, default: None):
    The labels for the head.

output_name (str | None, default: None):
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenKMerHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self.get_attention_mask(input_ids)
    output, attention_mask, _ = self.remove_special_tokens(output, attention_mask, input_ids)

    output = self.unfold_kmer_embeddings(output, attention_mask)
    return super().forward(output, labels, **kwargs)
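
A minimal usage sketch, under the same assumptions as the TokenPredictionHead example above. TokenKMerHead additionally reads config.nmers, so the attribute set below is illustrative of however the model config provides it.

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import TokenKMerHead
>>> config = PreTrainedConfig(hidden_size=8, num_labels=2)
>>> config.nmers = 3  # illustrative: TokenKMerHead reads `config.nmers`
>>> head = TokenKMerHead(config)
>>> hidden = torch.randn(1, 7, config.hidden_size)
>>> output = head({"last_hidden_state": hidden}, attention_mask=torch.ones(1, 7))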

unfold_kmer_embeddings

Python
unfold_kmer_embeddings(embeddings: Tensor, attention_mask: Tensor, nmers: int, bos_token_id: int | None = None, eos_token_id: int | None = None) -> Tensor

Unfold k-mer embeddings to token embeddings.

For k-mer input, each embedding column represents k tokens. This is fine for sequence-level tasks, but it sacrifices resolution for token-level tasks. This function unfolds the k-mer embeddings to token embeddings by averaging the k-mer embeddings with a sliding window.

For example:

input tokens = ACGU

2-mer embeddings = [<CLS>, AC, CG, GU, <SEP>].

token embeddings = [<CLS>, AC, (AC + CG) / 2, (CG + GU) / 2, GU, <SEP>].

Parameters:

embeddings (Tensor, required):
    The k-mer embeddings.

attention_mask (Tensor, required):
    The attention mask.

nmers (int, required):
    The number of tokens in each k-mer.

bos_token_id (int | None, default: None):
    The id of the beginning of sequence token. If not None, the first valid token will not be included in sliding averaging.

eos_token_id (int | None, default: None):
    The id of the end of sequence token. If not None, the last valid token will not be included in sliding averaging.

Returns:

Tensor:
    The token embeddings.

Examples:

Python Console Session
>>> from danling import NestedTensor
>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]
>>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
>>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0]
>>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
>>> output[0, :, 0].tolist()
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
Source code in multimolecule/modules/heads/token.py
Python
def unfold_kmer_embeddings(
    embeddings: Tensor,
    attention_mask: Tensor,
    nmers: int,
    bos_token_id: int | None = None,
    eos_token_id: int | None = None,
) -> Tensor:
    r"""
    Unfold k-mer embeddings to token embeddings.

    For k-mer input, each embedding column represents k tokens.
    This should be fine for sequence level tasks, but sacrifices the resolution for token level tasks.
    This function unfolds the k-mer embeddings to token embeddings by sliding averaging the k-mer embeddings.

    For example:

    input tokens = `ACGU`

    2-mer embeddings = `[<CLS>, AC, CG, GU, <SEP>]`.

    token embeddings = `[<CLS>, AC, (AC + CG) / 2, (CG + GU) / 2, GU, <SEP>]`.

    Args:
        embeddings: The k-mer embeddings.
        attention_mask: The attention mask.
        nmers: The number of tokens in each k-mer.
        bos_token_id: The id of the beginning of sequence token.
            If not None, the first valid token will not be included in sliding averaging.
        eos_token_id: The id of the end of sequence token.
            If not None, the last valid token will not be included in sliding averaging.

    Returns:
        The token embeddings.

    Examples:
        >>> from danling import NestedTensor
        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]
        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
        >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
        >>> output[0, :, 0].tolist()
        [1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0]
        >>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1
        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
        >>> output[0, :, 0].tolist()
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
        >>> output[1, :, 0].tolist()
        [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
    """

    batch_size, seq_length, hidden_size = embeddings.size()
    last_valid_indices = attention_mask.sum(dim=-1)
    output = torch.zeros(batch_size, seq_length + nmers - 1, hidden_size, device=embeddings.device)
    for index, (tensor, seq_length) in enumerate(zip(embeddings, last_valid_indices)):
        embedding = tensor[:seq_length]
        if bos_token_id is not None:
            embedding = embedding[1:]
        if eos_token_id is not None:
            embedding = embedding[:-1]
        if len(embedding) > nmers:
            begin = torch.stack([embedding[:i].mean(0) for i in range(1, nmers)])
            medium = embedding.unfold(0, nmers, 1).mean(-1)
            end = torch.stack([embedding[-i:].mean(0) for i in range(nmers - 1, 0, -1)])
            embedding = torch.cat([begin, medium, end])
        elif len(embedding) > 2:
            begin = torch.stack([embedding[:i].mean(0) for i in range(1, len(embedding))])
            end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)])
            embedding = torch.cat([begin, end])
        elif len(embedding) == 2:
            medium = embedding.mean(0).repeat(nmers - 1, 1)
            embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]])
        elif len(embedding) == 1:
            embedding = embedding.repeat(nmers, 1)
        else:
            raise ValueError("Sequence length is less than nmers.")
        if bos_token_id is not None:
            embedding = torch.cat([tensor[0][None, :], embedding])
        if eos_token_id is not None:
            embedding = torch.cat([embedding, tensor[seq_length - 1][None, :]])
        output[index, : seq_length + nmers - 1] = embedding
    return output

multimolecule.modules.heads.contact

ContactPredictionHead

Bases: BasePredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactPredictionHead
>>> config = PreTrainedConfig(hidden_size=8)
>>> head = ContactPredictionHead(config)
>>> input = torch.randn(1, 28, config.hidden_size)
>>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.logits.register("linear", default=True)
class ContactPredictionHead(BasePredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactPredictionHead
        >>> config = PreTrainedConfig(hidden_size=8)
        >>> head = ContactPredictionHead(config)
        >>> input = torch.randn(1, 28, config.hidden_size)
        >>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    require_attentions: bool = False
    r"""Whether the head requires attentions."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        out_channels: int = self.config.hidden_size  # type: ignore[assignment]
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(out_channels, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CRITERIONS.build(self.config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        output, _, _ = self.remove_special_tokens(output, attention_mask, input_ids)

        output = self.dropout(output)
        output = self.transform(output)
        contact_map = output.unsqueeze(1) * output.unsqueeze(2)
        contact_map = self.decoder(contact_map)
        contact_map = self.symmetrize(contact_map)
        if self.activation is not None:
            contact_map = self.activation(contact_map)

        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(contact_map, Tensor):
                    contact_map = labels.nested_like(contact_map, strict=False)
                return HeadOutput(contact_map, self.criterion(contact_map.concat, labels.concat))
            return HeadOutput(contact_map, self.criterion(contact_map, labels))
        return HeadOutput(contact_map)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = False

Whether the head requires attentions.

ContactPredictionResNetHead

Bases: ContactPredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactPredictionResNetHead
>>> config = PreTrainedConfig(hidden_size=32)
>>> head = ContactPredictionResNetHead(config)
>>> input = torch.randn(1, 28, config.hidden_size)
>>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.logits.register("resnet")
class ContactPredictionResNetHead(ContactPredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactPredictionResNetHead
        >>> config = PreTrainedConfig(hidden_size=32)
        >>> head = ContactPredictionResNetHead(config)
        >>> input = torch.randn(1, 28, config.hidden_size)
        >>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = ResNet(
            num_layers=self.config.get("num_layers", 6),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

ContactPredictionUNetHead

Bases: ContactPredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactPredictionUNetHead
>>> config = PreTrainedConfig(hidden_size=32)
>>> head = ContactPredictionUNetHead(config)
>>> input = torch.randn(1, 28, config.hidden_size)
>>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.logits.register("unet")
class ContactPredictionUNetHead(ContactPredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactPredictionUNetHead
        >>> config = PreTrainedConfig(hidden_size=32)
        >>> head = ContactPredictionUNetHead(config)
        >>> input = torch.randn(1, 28, config.hidden_size)
        >>> output = head({"last_hidden_state": input}, attention_mask=torch.ones(1, 28))
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = UNet(
            num_layers=self.config.get("num_layers", 6),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

ContactAttentionHead

Bases: BasePredictionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactAttentionHead
>>> config = PreTrainedConfig(num_hidden_layers=2, num_attention_heads=4)
>>> head = ContactAttentionHead(config)
>>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
>>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.attention.register("linear")
class ContactAttentionHead(BasePredictionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactAttentionHead
        >>> config = PreTrainedConfig(num_hidden_layers=2, num_attention_heads=4)
        >>> head = ContactAttentionHead(config)
        >>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
        >>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
    """

    output_name: str = "attentions"
    r"""The default output to use for the head."""

    require_attentions: bool = True
    r"""Whether the head requires attentions."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        if head_config is None:
            head_config = HeadConfig(hidden_size=config.num_hidden_layers * config.num_attention_heads)
        else:
            head_config.hidden_size = config.num_hidden_layers * config.num_attention_heads
        super().__init__(config, head_config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CRITERIONS.build(self.config)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[-1]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if isinstance(output, (list, tuple)):
            output = torch.stack(output, 1)
        contact_map = output.flatten(1, 2).permute(0, 2, 3, 1)

        if attention_mask is None:
            attention_mask = self.get_attention_mask(input_ids)
        contact_map, _, _ = self.remove_special_tokens_2d(contact_map, attention_mask, input_ids)

        contact_map = self.dropout(contact_map)
        contact_map = self.transform(contact_map)
        contact_map = self.decoder(contact_map)
        contact_map = self.symmetrize(contact_map)
        if self.activation is not None:
            contact_map = self.activation(contact_map)

        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(contact_map, Tensor):
                    contact_map = labels.nested_like(contact_map, strict=False)
                return HeadOutput(contact_map, self.criterion(contact_map.concat, labels.concat))
            return HeadOutput(contact_map, self.criterion(contact_map, labels))
        return HeadOutput(contact_map)

output_name class-attribute instance-attribute

Python
output_name: str = 'attentions'

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = True

Whether the head requires attentions.

ContactAttentionResNetHead

Bases: ContactAttentionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactAttentionResNetHead
>>> config = PreTrainedConfig(num_hidden_layers=8, num_attention_heads=4)
>>> head = ContactAttentionResNetHead(config)
>>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
>>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.attention.register("resnet")
class ContactAttentionResNetHead(ContactAttentionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactAttentionResNetHead
        >>> config = PreTrainedConfig(num_hidden_layers=8, num_attention_heads=4)
        >>> head = ContactAttentionResNetHead(config)
        >>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
        >>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = ResNet(
            num_layers=self.config.get("num_layers", 16),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

ContactAttentionUNetHead

Bases: ContactAttentionHead

Head for contact-level tasks.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import ContactAttentionUNetHead
>>> config = PreTrainedConfig(num_hidden_layers=4, num_attention_heads=8)
>>> head = ContactAttentionUNetHead(config)
>>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
>>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
Source code in multimolecule/modules/heads/contact.py
Python
@HEADS.contact.attention.register("unet")
class ContactAttentionUNetHead(ContactAttentionHead):
    r"""
    Head for tasks in contact-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.

    Examples:
        >>> import torch
        >>> from multimolecule.models import PreTrainedConfig
        >>> from multimolecule.modules.heads import ContactAttentionUNetHead
        >>> config = PreTrainedConfig(num_hidden_layers=4, num_attention_heads=8)
        >>> head = ContactAttentionUNetHead(config)
        >>> input = tuple(torch.randn(1, config.num_attention_heads, 28, 28) for _ in range(config.num_hidden_layers))
        >>> output = head({"attentions": input}, attention_mask=torch.ones(1, 28))
    """

    output_name: str = "attentions"
    r"""The default output to use for the head."""

    require_attentions: bool = True
    r"""Whether the head requires attentions."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.decoder = UNet(
            num_layers=self.config.get("num_layers", 4),
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=self.config.get("block", "auto"),
            num_channels=self.config.get("num_channels"),
            num_labels=self.num_labels,
        )

output_name class-attribute instance-attribute

Python
output_name: str = 'attentions'

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = True

Whether the head requires attentions.

multimolecule.modules.heads.pretrain

MaskedLMHead

Bases: BasePredictionHead

Head for masked language modeling.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (MaskedLMHeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/pretrain.py
Python
@HEADS.register("masked_lm")
class MaskedLMHead(BasePredictionHead):
    r"""
    Head for masked language modeling.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(
        self, config: PreTrainedConfig, weight: Tensor | None = None, head_config: MaskedLMHeadConfig | None = None
    ):
        if head_config is None:
            head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig()
        head_config.num_labels = config.vocab_size
        super().__init__(config, head_config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        if weight is not None:
            self.decoder.weight = weight
        if self.config.bias:
            self.bias = nn.Parameter(torch.zeros(self.num_labels))
            self.decoder.bias = self.bias
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

    def forward(
        self,
        outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
        labels: Tensor | None = None,
        output_name: str | None = None,
    ) -> HeadOutput:
        r"""
        Forward pass of the MaskedLMHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        output = self.dropout(output)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)

        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
            return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
        return HeadOutput(output)

    def _tie_weights(self):
        self.decoder.bias = self.bias

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None) -> HeadOutput

Forward pass of the MaskedLMHead.

Parameters:

outputs (ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...], required):
    The outputs of the model.

labels (Tensor | None, default: None):
    The labels for the head.

output_name (str | None, default: None):
    The name of the output to use. Defaults to self.output_name.
Source code in multimolecule/modules/heads/pretrain.py
Python
def forward(
    self,
    outputs: ModelOutput | Mapping[str, Tensor] | Tuple[Tensor, ...],
    labels: Tensor | None = None,
    output_name: str | None = None,
) -> HeadOutput:
    r"""
    Forward pass of the MaskedLMHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    output = self.dropout(output)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)

    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
        return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
    return HeadOutput(output)
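
A minimal usage sketch, assuming MaskedLMHead is importable from multimolecule.modules.heads like the other heads on this page and that PreTrainedConfig provides a default vocab_size (the decoder projects each token to config.vocab_size logits).

Python Console Session
>>> import torch
>>> from multimolecule.models import PreTrainedConfig
>>> from multimolecule.modules.heads import MaskedLMHead
>>> config = PreTrainedConfig(hidden_size=8)
>>> head = MaskedLMHead(config)
>>> hidden = torch.randn(1, 5, config.hidden_size)
>>> output = head({"last_hidden_state": hidden})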

multimolecule.modules.heads.generic

BasePredictionHead

Bases: Module

Head for tasks at all levels.

Parameters:

config (PreTrainedConfig, required):
    The configuration object for the model.

head_config (HeadConfig | None, default: None):
    The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/generic.py
Python
class BasePredictionHead(nn.Module):
    r"""
    Head for all-level of tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    config: HeadConfig
    r"""The configuration object for the head."""

    num_labels: int
    r"""Number of labels for the head."""

    output_name: str | None
    r"""The default output to use for the head."""

    require_attentions: bool = False
    r"""Whether the head requires attentions from the model."""

    bos_token_id: int | None = None
    r"""The ID of the beginning-of-sequence token. Usually is an alias of `cls_token_id`."""

    pad_token_id: int | None = None
    r"""The ID of the padding token."""

    eos_token_id: int | None = None
    r"""The ID of the end-of-sequence token. In rare cases, it is an alias of `sep_token_id`."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__()
        if head_config is None:
            head_config = config.head or HeadConfig(num_labels=config.num_labels)
        if not isinstance(head_config, HeadConfig):
            head_config = HeadConfig(head_config)
        if not head_config.num_labels:
            head_config.num_labels = config.num_labels
        if not head_config.hidden_size:
            head_config.hidden_size = config.hidden_size
        if not head_config.problem_type:
            head_config.problem_type = config.problem_type
        self.config = head_config
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
        self.num_labels = self.config.num_labels  # type: ignore[assignment]
        if getattr(self.config, "output_name", None) is not None:
            self.output_name = self.config.output_name

    def get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor:
        r"""
        Generate attention mask from input IDs or extract from NestedTensor.

        Creates a binary attention mask indicating which tokens should be attended to (1)
        and which should be ignored (0, typically padding tokens). For NestedTensor inputs,
        extracts the pre-computed mask. For regular tensors, compares against pad_token_id.

        Args:
            input_ids: Input token IDs as either a NestedTensor with embedded mask
                or a regular Tensor of shape `(batch_size, seq_len)`.

        Returns:
            Binary attention mask of shape `(batch_size, seq_len)` where 1 indicates
            tokens to attend to and 0 indicates tokens to ignore.

        Raises:
            ValueError: If input_ids is None or if pad_token_id is None when needed
                for regular Tensor inputs.

        Examples:
            >>> import torch
            >>> from multimolecule.models.configuration_utils import PreTrainedConfig
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=128))
            >>> input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
            >>> mask = head.get_attention_mask(input_ids)
            >>> mask
            tensor([[1, 1, 1, 0, 0],
                    [1, 1, 0, 0, 0]], dtype=torch.int32)
        """
        if isinstance(input_ids, NestedTensor):
            return input_ids.mask
        if input_ids is None:
            raise ValueError(
                f"Unable to infer attention mask for {self.__class__.__name__}, because input_ids is None."
            )
        if self.pad_token_id is None:
            raise ValueError(
                f"Unable to infer attention mask for {self.__class__.__name__}, because pad_token_id is None."
            )
        return input_ids.ne(self.pad_token_id).int()

    def remove_special_tokens(
        self, output: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Remove special tokens and clean up model outputs using attention masks.

        Processes model outputs by removing special tokens that were added during tokenization
        and applies attention masking to zero out padding positions. This comprehensive cleanup
        is essential for sequence-level tasks where predictions should only cover the actual
        input sequence, excluding special tokens and padding.

        The method performs:
        - BOS token removal: Strips the first token from all sequences
        - EOS token removal: Strips tokens after the EOS token and the EOS token itself
        - Attention mask adjustment: Updates mask to match the trimmed sequences
        - Output cleanup: Multiplies output by attention mask to zero out padding positions

        Args:
            output: Model output tensor of shape `(batch_size, seq_len, hidden_size)`.
            attention_mask: Attention mask of shape `(batch_size, seq_len)`.
            input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.
                Used for precise EOS token location when available.

        Returns:
            Tuple containing:
                - output: Cleaned output tensor with special tokens removed and padding zeroed
                - attention_mask: Updated attention mask matching trimmed sequences
                - input_ids: Trimmed input IDs (if provided) or unchanged input

        Examples:
            >>> import torch
            >>> from multimolecule.models.configuration_utils import PreTrainedConfig
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
            >>> output = torch.randn(1, 6, 4)
            >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
            >>> input_ids = torch.tensor([[1, 10, 20, 2, 0, 0]])
            >>> new_out, new_mask, new_ids = head.remove_special_tokens(output, attention_mask, input_ids)
            >>> output.shape[1], new_out.shape[1]
            (6, 4)
            >>> new_mask
            tensor([[1, 1, 0, 0]])
        """
        if self.bos_token_id is not None:
            output = output[..., 1:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
                if isinstance(input_ids, Tensor):
                    input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
                if isinstance(eos_mask, NestedTensor):
                    eos_mask = eos_mask.tensor
                input_ids = input_ids[..., :-1]
            elif attention_mask is not None:
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            else:
                raise ValueError("Unable to remove EOS tokens because input_ids and attention_mask are both None")
            output = (output * eos_mask.unsqueeze(-1))[..., :-1, :]
            if attention_mask is not None:
                attention_mask = (attention_mask * eos_mask)[..., :-1]
        if attention_mask is not None:
            output = output * attention_mask.unsqueeze(-1)
        return output, attention_mask, input_ids

    def remove_special_tokens_2d(
        self, output: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Remove special tokens from 2D outputs like contact maps or pairwise interaction matrices.

        Extends `remove_special_tokens` to handle 2D outputs where both sequence dimensions
        need special token removal. This is crucial for contact prediction and structure
        analysis tasks where the output represents pairwise relationships between residues.

        The method removes:
        - BOS tokens: Strips first row and column from the 2D output
        - EOS tokens: Strips rows/columns after EOS positions and the EOS positions themselves
        - Updates attention mask: Creates 2D mask from 1D sequence mask

        Args:
            output: 2D model output of shape `(batch_size, seq_len, seq_len, channels)`.
            attention_mask: 1D attention mask of shape `(batch_size, seq_len)`.
            input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.

        Returns:
            Tuple containing:
                - output: Trimmed 2D output with special tokens removed from both dimensions
                - attention_mask: 2D attention mask of shape `(batch_size, new_len, new_len)`
                - input_ids: Trimmed input IDs (if provided) or unchanged input

        Examples:
            >>> import torch
            >>> from multimolecule.models.configuration_utils import PreTrainedConfig
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
            >>> output = torch.randn(1, 5, 5, 1)
            >>> input_ids = torch.tensor([[1, 10, 20, 2, 0]])
            >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
            >>> new_out, new_mask, new_ids = head.remove_special_tokens_2d(output, attention_mask, input_ids)
            >>> output.shape, new_out.shape
            (torch.Size([1, 5, 5, 1]), torch.Size([1, 3, 3, 1]))
            >>> new_mask.shape
            torch.Size([1, 3, 3])
        """
        if self.bos_token_id is not None:
            output = output[..., 1:, 1:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
                if isinstance(input_ids, Tensor):
                    input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
                if isinstance(eos_mask, NestedTensor):
                    eos_mask = eos_mask.tensor
                input_ids = input_ids[..., :-1]
            elif attention_mask is not None:
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            else:
                raise ValueError("Unable to remove EOS tokens because input_ids and attention_mask are both None")
            if attention_mask is not None:
                attention_mask = (attention_mask * eos_mask)[..., :-1]
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            output = (output * eos_mask.unsqueeze(-1))[..., :-1, :-1, :]
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
            output = output * attention_mask.unsqueeze(-1)
        return output, attention_mask, input_ids

    @staticmethod
    def symmetrize(x: Tensor) -> Tensor:
        r"""
        Make output symmetric by averaging the tensor with its transpose.

        Args:
            x: Input tensor of shape (batch_size, seq_len, seq_len, channels).

        Returns:
            Symmetric tensor with the same shape as input.

        Examples:
            >>> import torch
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> x = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
            >>> x.squeeze(-1)
            tensor([[[1., 2.],
                     [3., 4.]]])
            >>> symmetric = BasePredictionHead.symmetrize(x)
            >>> symmetric.squeeze(-1)
            tensor([[[1.0000, 2.5000],
                     [2.5000, 4.0000]]])
            >>> torch.allclose(symmetric, symmetric.transpose(1, 2))
            True
        """
        return (x + x.transpose(1, 2)) / 2

    @staticmethod
    def average_product_correct(x: Tensor) -> Tensor:
        r"""Perform Average Product Correction (APC) to remove systematic biases from contact maps.

        APC removes row and column biases that arise from varying residue frequencies and
        structural preferences in molecular contact maps. It subtracts the expected contact
        probability based on marginal frequencies to reveal genuine structural interactions.

        The correction formula: `corrected = original - (row_sums × col_sums) / total_sum`

        This is essential for accurate contact prediction across DNA, RNA, and protein structures.

        Args:
            x: Contact map tensor of shape `(batch_size, seq_len, seq_len, channels)`

        Returns:
            Bias-corrected contact map with the same shape as input

        Note:
            This correction removes spurious correlations caused by sequence composition bias,
            making genuine molecular contacts stand out more clearly from background noise.

        Examples:
            >>> import torch
            >>> from multimolecule.modules.heads.generic import BasePredictionHead
            >>> x = torch.tensor([[[[0.8, 0.6], [0.7, 0.5]], [[0.6, 0.4], [0.5, 0.3]]]])
            >>> x.squeeze(-1)
            tensor([[[[0.8000, 0.6000],
                      [0.7000, 0.5000]],
            <BLANKLINE>
                     [[0.6000, 0.4000],
                      [0.5000, 0.3000]]]])
            >>> corrected = BasePredictionHead.average_product_correct(x)
            >>> corrected.squeeze(-1)
            tensor([[[[-0.0077, -0.0111],
                      [ 0.0077,  0.0111]],
            <BLANKLINE>
                     [[ 0.0077,  0.0111],
                      [-0.0077, -0.0111]]]])
            >>> row_sums = corrected.sum(dim=2).squeeze()
            >>> col_sums = corrected.sum(dim=1).squeeze()
            >>> torch.allclose(row_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
            True
            >>> torch.allclose(col_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
            True
        """
        return x - x.sum(1, keepdims=True) * x.sum(2, keepdims=True) / x.sum((1, 2), keepdims=True)
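
For contact-style outputs, symmetrize and average_product_correct are typically chained: enforce symmetry on the pairwise map first, then remove composition bias with APC. A minimal sketch (the random tensor below is illustrative, not a model output):

Python
import torch
from multimolecule.modules.heads.generic import BasePredictionHead

# Hypothetical pairwise map of shape (batch_size, seq_len, seq_len, channels).
pairwise = torch.randn(1, 8, 8, 1)

# Symmetrize first, then apply Average Product Correction.
contacts = BasePredictionHead.average_product_correct(BasePredictionHead.symmetrize(pairwise))

# Symmetry is preserved by the correction, and corrected rows/columns sum to ~0.
assert torch.allclose(contacts, contacts.transpose(1, 2))
assert torch.allclose(contacts.sum(dim=1), torch.zeros(1, 8, 1), atol=1e-5)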

output_name instance-attribute

Python
output_name: str | None

The default output to use for the head.

require_attentions class-attribute instance-attribute

Python
require_attentions: bool = False

Whether the head requires attentions from the model.

config instance-attribute

Python
config: HeadConfig = head_config

The configuration object for the head.

bos_token_id class-attribute instance-attribute

Python
bos_token_id: int | None = bos_token_id

The ID of the beginning-of-sequence token. Usually is an alias of cls_token_id.

eos_token_id class-attribute instance-attribute

Python
eos_token_id: int | None = eos_token_id

The ID of the end-of-sequence token. In rare cases, it is an alias of sep_token_id.

pad_token_id class-attribute instance-attribute

Python
pad_token_id: int | None = pad_token_id

The ID of the padding token.

num_labels instance-attribute

Python
num_labels: int = num_labels

Number of labels for the head.

get_attention_mask

Python
get_attention_mask(input_ids: NestedTensor | Tensor) -> Tensor

Generate attention mask from input IDs or extract from NestedTensor.

Creates a binary attention mask indicating which tokens should be attended to (1) and which should be ignored (0, typically padding tokens). For NestedTensor inputs, extracts the pre-computed mask. For regular tensors, compares against pad_token_id.

Parameters:

    input_ids (NestedTensor | Tensor, required):
        Input token IDs as either a NestedTensor with embedded mask or a regular
        Tensor of shape (batch_size, seq_len).

Returns:

    Tensor:
        Binary attention mask of shape (batch_size, seq_len) where 1 indicates
        tokens to attend to and 0 indicates tokens to ignore.

Raises:

    ValueError:
        If input_ids is None or if pad_token_id is None when needed for regular
        Tensor inputs.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.configuration_utils import PreTrainedConfig
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=128))
>>> input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
>>> mask = head.get_attention_mask(input_ids)
>>> mask
tensor([[1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0]], dtype=torch.int32)
Source code in multimolecule/modules/heads/generic.py
Python
def get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor:
    r"""
    Generate attention mask from input IDs or extract from NestedTensor.

    Creates a binary attention mask indicating which tokens should be attended to (1)
    and which should be ignored (0, typically padding tokens). For NestedTensor inputs,
    extracts the pre-computed mask. For regular tensors, compares against pad_token_id.

    Args:
        input_ids: Input token IDs as either a NestedTensor with embedded mask
            or a regular Tensor of shape `(batch_size, seq_len)`.

    Returns:
        Binary attention mask of shape `(batch_size, seq_len)` where 1 indicates
        tokens to attend to and 0 indicates tokens to ignore.

    Raises:
        ValueError: If input_ids is None or if pad_token_id is None when needed
            for regular Tensor inputs.

    Examples:
        >>> import torch
        >>> from multimolecule.models.configuration_utils import PreTrainedConfig
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=128))
        >>> input_ids = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
        >>> mask = head.get_attention_mask(input_ids)
        >>> mask
        tensor([[1, 1, 1, 0, 0],
                [1, 1, 0, 0, 0]], dtype=torch.int32)
    """
    if isinstance(input_ids, NestedTensor):
        return input_ids.mask
    if input_ids is None:
        raise ValueError(
            f"Unable to infer attention mask for {self.__class__.__name__}, because input_ids is None."
        )
    if self.pad_token_id is None:
        raise ValueError(
            f"Unable to infer attention mask for {self.__class__.__name__}, because pad_token_id is None."
        )
    return input_ids.ne(self.pad_token_id).int()

remove_special_tokens

Python
remove_special_tokens(output: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None) -> Tuple[Tensor, Tensor, Tensor]

Remove special tokens and clean up model outputs using attention masks.

Processes model outputs by removing special tokens that were added during tokenization and applies attention masking to zero out padding positions. This comprehensive cleanup is essential for sequence-level tasks where predictions should only cover the actual input sequence, excluding special tokens and padding.

The method performs:

- BOS token removal: Strips the first token from all sequences
- EOS token removal: Strips tokens after the EOS token and the EOS token itself
- Attention mask adjustment: Updates mask to match the trimmed sequences
- Output cleanup: Multiplies output by attention mask to zero out padding positions

Parameters:

    output (Tensor, required):
        Model output tensor of shape (batch_size, seq_len, hidden_size).
    attention_mask (Tensor | None, default None):
        Attention mask of shape (batch_size, seq_len).
    input_ids (Tensor | None, default None):
        Optional input token IDs of shape (batch_size, seq_len). Used for precise
        EOS token location when available.

Returns:

    Tuple[Tensor, Tensor, Tensor]:
        - output: Cleaned output tensor with special tokens removed and padding zeroed
        - attention_mask: Updated attention mask matching trimmed sequences
        - input_ids: Trimmed input IDs (if provided) or unchanged input

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.configuration_utils import PreTrainedConfig
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
>>> output = torch.randn(1, 6, 4)
>>> attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
>>> input_ids = torch.tensor([[1, 10, 20, 2, 0, 0]])
>>> new_out, new_mask, new_ids = head.remove_special_tokens(output, attention_mask, input_ids)
>>> output.shape[1], new_out.shape[1]
(6, 4)
>>> new_mask
tensor([[1, 1, 0, 0]])
Source code in multimolecule/modules/heads/generic.py
Python
def remove_special_tokens(
    self, output: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Remove special tokens and clean up model outputs using attention masks.

    Processes model outputs by removing special tokens that were added during tokenization
    and applies attention masking to zero out padding positions. This comprehensive cleanup
    is essential for sequence-level tasks where predictions should only cover the actual
    input sequence, excluding special tokens and padding.

    The method performs:
    - BOS token removal: Strips the first token from all sequences
    - EOS token removal: Strips tokens after the EOS token and the EOS token itself
    - Attention mask adjustment: Updates mask to match the trimmed sequences
    - Output cleanup: Multiplies output by attention mask to zero out padding positions

    Args:
        output: Model output tensor of shape `(batch_size, seq_len, hidden_size)`.
        attention_mask: Attention mask of shape `(batch_size, seq_len)`.
        input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.
            Used for precise EOS token location when available.

    Returns:
        Tuple containing:
            - output: Cleaned output tensor with special tokens removed and padding zeroed
            - attention_mask: Updated attention mask matching trimmed sequences
            - input_ids: Trimmed input IDs (if provided) or unchanged input

    Examples:
        >>> import torch
        >>> from multimolecule.models.configuration_utils import PreTrainedConfig
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
        >>> output = torch.randn(1, 6, 4)
        >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
        >>> input_ids = torch.tensor([[1, 10, 20, 2, 0, 0]])
        >>> new_out, new_mask, new_ids = head.remove_special_tokens(output, attention_mask, input_ids)
        >>> output.shape[1], new_out.shape[1]
        (6, 4)
        >>> new_mask
        tensor([[1, 1, 0, 0]])
    """
    if self.bos_token_id is not None:
        output = output[..., 1:, :]
        if attention_mask is not None:
            attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
            if isinstance(input_ids, Tensor):
                input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
            if isinstance(eos_mask, NestedTensor):
                eos_mask = eos_mask.tensor
            input_ids = input_ids[..., :-1]
        elif attention_mask is not None:
            last_valid_indices = attention_mask.sum(dim=-1) - 1
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
        else:
            raise ValueError("Unable to remove EOS tokens because input_ids and attention_mask are both None")
        output = (output * eos_mask.unsqueeze(-1))[..., :-1, :]
        if attention_mask is not None:
            attention_mask = (attention_mask * eos_mask)[..., :-1]
    if attention_mask is not None:
        output = output * attention_mask.unsqueeze(-1)
    return output, attention_mask, input_ids

remove_special_tokens_2d

Python
remove_special_tokens_2d(output: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None) -> Tuple[Tensor, Tensor, Tensor]

Remove special tokens from 2D outputs like contact maps or pairwise interaction matrices.

Extends remove_special_tokens to handle 2D outputs where both sequence dimensions need special token removal. This is crucial for contact prediction and structure analysis tasks where the output represents pairwise relationships between residues.

The method removes:

- BOS tokens: Strips first row and column from the 2D output
- EOS tokens: Strips rows/columns after EOS positions and the EOS positions themselves
- Updates attention mask: Creates 2D mask from 1D sequence mask

Parameters:

    output (Tensor, required):
        2D model output of shape (batch_size, seq_len, seq_len, channels).
    attention_mask (Tensor | None, default None):
        1D attention mask of shape (batch_size, seq_len).
    input_ids (Tensor | None, default None):
        Optional input token IDs of shape (batch_size, seq_len).

Returns:

    Tuple[Tensor, Tensor, Tensor]:
        - output: Trimmed 2D output with special tokens removed from both dimensions
        - attention_mask: 2D attention mask of shape (batch_size, new_len, new_len)
        - input_ids: Trimmed input IDs (if provided) or unchanged input

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.models.configuration_utils import PreTrainedConfig
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
>>> output = torch.randn(1, 5, 5, 1)
>>> input_ids = torch.tensor([[1, 10, 20, 2, 0]])
>>> attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
>>> new_out, new_mask, new_ids = head.remove_special_tokens_2d(output, attention_mask, input_ids)
>>> output.shape, new_out.shape
(torch.Size([1, 5, 5, 1]), torch.Size([1, 3, 3, 1]))
>>> new_mask.shape
torch.Size([1, 3, 3])
Source code in multimolecule/modules/heads/generic.py
Python
def remove_special_tokens_2d(
    self, output: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Remove special tokens from 2D outputs like contact maps or pairwise interaction matrices.

    Extends `remove_special_tokens` to handle 2D outputs where both sequence dimensions
    need special token removal. This is crucial for contact prediction and structure
    analysis tasks where the output represents pairwise relationships between residues.

    The method removes:
    - BOS tokens: Strips first row and column from the 2D output
    - EOS tokens: Strips rows/columns after EOS positions and the EOS positions themselves
    - Updates attention mask: Creates 2D mask from 1D sequence mask

    Args:
        output: 2D model output of shape `(batch_size, seq_len, seq_len, channels)`.
        attention_mask: 1D attention mask of shape `(batch_size, seq_len)`.
        input_ids: Optional input token IDs of shape `(batch_size, seq_len)`.

    Returns:
        Tuple containing:
            - output: Trimmed 2D output with special tokens removed from both dimensions
            - attention_mask: 2D attention mask of shape `(batch_size, new_len, new_len)`
            - input_ids: Trimmed input IDs (if provided) or unchanged input

    Examples:
        >>> import torch
        >>> from multimolecule.models.configuration_utils import PreTrainedConfig
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> head = BasePredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
        >>> output = torch.randn(1, 5, 5, 1)
        >>> input_ids = torch.tensor([[1, 10, 20, 2, 0]])
        >>> attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
        >>> new_out, new_mask, new_ids = head.remove_special_tokens_2d(output, attention_mask, input_ids)
        >>> output.shape, new_out.shape
        (torch.Size([1, 5, 5, 1]), torch.Size([1, 3, 3, 1]))
        >>> new_mask.shape
        torch.Size([1, 3, 3])
    """
    if self.bos_token_id is not None:
        output = output[..., 1:, 1:, :]
        if attention_mask is not None:
            attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(output.device)
            if isinstance(input_ids, Tensor):
                input_ids.masked_fill_(~eos_mask, self.pad_token_id or 0)
            if isinstance(eos_mask, NestedTensor):
                eos_mask = eos_mask.tensor
            input_ids = input_ids[..., :-1]
        elif attention_mask is not None:
            last_valid_indices = attention_mask.sum(dim=-1) - 1
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
        else:
            raise ValueError("Unable to remove EOS tokens because input_ids and attention_mask are both None")
        if attention_mask is not None:
            attention_mask = (attention_mask * eos_mask)[..., :-1]
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        output = (output * eos_mask.unsqueeze(-1))[..., :-1, :-1, :]
    if attention_mask is not None:
        attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
        output = output * attention_mask.unsqueeze(-1)
    return output, attention_mask, input_ids

symmetrize staticmethod

Python
symmetrize(x: Tensor) -> Tensor

Make output symmetric by averaging the tensor with its transpose.

Parameters:

    x (Tensor, required):
        Input tensor of shape (batch_size, seq_len, seq_len, channels).

Returns:

    Tensor:
        Symmetric tensor with the same shape as input.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> x = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
>>> x.squeeze(-1)
tensor([[[1., 2.],
         [3., 4.]]])
>>> symmetric = BasePredictionHead.symmetrize(x)
>>> symmetric.squeeze(-1)
tensor([[[1.0000, 2.5000],
         [2.5000, 4.0000]]])
>>> torch.allclose(symmetric, symmetric.transpose(1, 2))
True
Source code in multimolecule/modules/heads/generic.py
Python
@staticmethod
def symmetrize(x: Tensor) -> Tensor:
    r"""
    Make output symmetric by averaging the tensor with its transpose.

    Args:
        x: Input tensor of shape (batch_size, seq_len, seq_len, channels).

    Returns:
        Symmetric tensor with the same shape as input.

    Examples:
        >>> import torch
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> x = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
        >>> x.squeeze(-1)
        tensor([[[1., 2.],
                 [3., 4.]]])
        >>> symmetric = BasePredictionHead.symmetrize(x)
        >>> symmetric.squeeze(-1)
        tensor([[[1.0000, 2.5000],
                 [2.5000, 4.0000]]])
        >>> torch.allclose(symmetric, symmetric.transpose(1, 2))
        True
    """
    return (x + x.transpose(1, 2)) / 2

average_product_correct staticmethod

Python
average_product_correct(x: Tensor) -> Tensor

Perform Average Product Correction (APC) to remove systematic biases from contact maps.

APC removes row and column biases that arise from varying residue frequencies and structural preferences in molecular contact maps. It subtracts the expected contact probability based on marginal frequencies to reveal genuine structural interactions.

The correction formula: corrected = original - (row_sums × col_sums) / total_sum

This is essential for accurate contact prediction across DNA, RNA, and protein structures.

Parameters:

    x (Tensor, required):
        Contact map tensor of shape (batch_size, seq_len, seq_len, channels).

Returns:

    Tensor:
        Bias-corrected contact map with the same shape as input.

Note

This correction removes spurious correlations caused by sequence composition bias, making genuine molecular contacts stand out more clearly from background noise.

Examples:

Python Console Session
>>> import torch
>>> from multimolecule.modules.heads.generic import BasePredictionHead
>>> x = torch.tensor([[[[0.8, 0.6], [0.7, 0.5]], [[0.6, 0.4], [0.5, 0.3]]]])
>>> x.squeeze(-1)
tensor([[[[0.8000, 0.6000],
          [0.7000, 0.5000]],

         [[0.6000, 0.4000],
          [0.5000, 0.3000]]]])
>>> corrected = BasePredictionHead.average_product_correct(x)
>>> corrected.squeeze(-1)
tensor([[[[-0.0077, -0.0111],
          [ 0.0077,  0.0111]],

         [[ 0.0077,  0.0111],
          [-0.0077, -0.0111]]]])
>>> row_sums = corrected.sum(dim=2).squeeze()
>>> col_sums = corrected.sum(dim=1).squeeze()
>>> torch.allclose(row_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
True
>>> torch.allclose(col_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
True
Source code in multimolecule/modules/heads/generic.py
Python
@staticmethod
def average_product_correct(x: Tensor) -> Tensor:
    r"""Perform Average Product Correction (APC) to remove systematic biases from contact maps.

    APC removes row and column biases that arise from varying residue frequencies and
    structural preferences in molecular contact maps. It subtracts the expected contact
    probability based on marginal frequencies to reveal genuine structural interactions.

    The correction formula: `corrected = original - (row_sums × col_sums) / total_sum`

    This is essential for accurate contact prediction across DNA, RNA, and protein structures.

    Args:
        x: Contact map tensor of shape `(batch_size, seq_len, seq_len, channels)`

    Returns:
        Bias-corrected contact map with the same shape as input

    Note:
        This correction removes spurious correlations caused by sequence composition bias,
        making genuine molecular contacts stand out more clearly from background noise.

    Examples:
        >>> import torch
        >>> from multimolecule.modules.heads.generic import BasePredictionHead
        >>> x = torch.tensor([[[[0.8, 0.6], [0.7, 0.5]], [[0.6, 0.4], [0.5, 0.3]]]])
        >>> x.squeeze(-1)
        tensor([[[[0.8000, 0.6000],
                  [0.7000, 0.5000]],
        <BLANKLINE>
                 [[0.6000, 0.4000],
                  [0.5000, 0.3000]]]])
        >>> corrected = BasePredictionHead.average_product_correct(x)
        >>> corrected.squeeze(-1)
        tensor([[[[-0.0077, -0.0111],
                  [ 0.0077,  0.0111]],
        <BLANKLINE>
                 [[ 0.0077,  0.0111],
                  [-0.0077, -0.0111]]]])
        >>> row_sums = corrected.sum(dim=2).squeeze()
        >>> col_sums = corrected.sum(dim=1).squeeze()
        >>> torch.allclose(row_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
        True
        >>> torch.allclose(col_sums, torch.tensor([[0.0, 0.0], [-0.0, -0.0]]), atol=1e-6)
        True
    """
    return x - x.sum(1, keepdims=True) * x.sum(2, keepdims=True) / x.sum((1, 2), keepdims=True)

PredictionHead

Bases: BasePredictionHead

Head for all levels of tasks.

Parameters:

    config (PreTrainedConfig, required):
        The configuration object for the model.
    head_config (HeadConfig | None, default None):
        The configuration object for the head. If None, will use configuration from the config.
Source code in multimolecule/modules/heads/generic.py
Python
class PredictionHead(BasePredictionHead):
    r"""
    Head for all levels of tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HEAD_TRANSFORMS_HF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CRITERIONS.build(self.config)

    def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
        r"""
        Forward pass of the PredictionHead.

        Args:
            embeddings: The embeddings to be passed through the head.
            labels: The labels for the head.
        """
        if kwargs:
            warn(
                f"The following arguments are not applicable to {self.__class__.__name__} "
                f"and will be ignored: {kwargs.keys()}"
            )
        output = self.dropout(embeddings)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, self.criterion(output.concat, labels.concat))
            return HeadOutput(output, self.criterion(output, labels))
        return HeadOutput(output)
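
A minimal usage sketch, assuming the defaults carried by PreTrainedConfig yield a working head configuration (the tensors below are illustrative):

Python
import torch
from multimolecule.models.configuration_utils import PreTrainedConfig
from multimolecule.modules.heads.generic import PredictionHead

head = PredictionHead(PreTrainedConfig(num_labels=2, hidden_size=4))
embeddings = torch.randn(1, 6, 4)  # (batch_size, seq_len, hidden_size)

# Without labels only logits are returned; passing labels additionally computes a loss
# via the configured criterion.
output = head(embeddings, labels=None)
print(output.logits.shape)  # (batch_size, seq_len, num_labels), i.e. torch.Size([1, 6, 2]) here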

forward

Python
forward(embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput

Forward pass of the PredictionHead.

Parameters:

    embeddings (Tensor, required):
        The embeddings to be passed through the head.
    labels (Tensor | None, required):
        The labels for the head.
Source code in multimolecule/modules/heads/generic.py
Python
def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
    r"""
    Forward pass of the PredictionHead.

    Args:
        embeddings: The embeddings to be passed through the head.
        labels: The labels for the head.
    """
    if kwargs:
        warn(
            f"The following arguments are not applicable to {self.__class__.__name__} "
            f"and will be ignored: {kwargs.keys()}"
        )
    output = self.dropout(embeddings)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)
    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, self.criterion(output.concat, labels.concat))
        return HeadOutput(output, self.criterion(output, labels))
    return HeadOutput(output)

multimolecule.modules.heads.output

HeadOutput dataclass

Bases: ModelOutput

Output of a prediction head.

Parameters:

    logits (FloatTensor, required):
        The prediction logits from the head.
    loss (FloatTensor | None, default None):
        The loss from the head. Defaults to None.
Source code in multimolecule/modules/heads/output.py
Python
@dataclass
class HeadOutput(ModelOutput):
    r"""
    Output of a prediction head.

    Args:
        logits: The prediction logits from the head.
        loss: The loss from the head.
            Defaults to None.
    """

    logits: FloatTensor
    loss: FloatTensor | None = None
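
Since HeadOutput is a ModelOutput, its fields can be read as attributes or, for fields that are set, by key. A small illustrative sketch:

Python
import torch
from multimolecule.modules.heads.output import HeadOutput

output = HeadOutput(logits=torch.randn(2, 3))
print(output.logits.shape)                 # torch.Size([2, 3])
print(output.loss)                         # None; populated only when a loss is computed
print(output["logits"] is output.logits)   # True: dict-style access works for set fields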