
heads

heads provides a collection of predefined prediction heads.

Heads take a ModelOutput, a dict, or a tuple as input. Each head automatically looks up the model output it needs for prediction and processes it accordingly.

Some prediction heads, such as ContactPredictionHead, may require additional information like the attention_mask or the input_ids. These can be passed as positional or keyword arguments.

Note that heads use the same ModelOutput conventions as 🤗 Transformers. If the model output is a tuple, the first element is treated as the last_hidden_state, the second element as the pooler_output, and the last element as the attentions. It is the user’s responsibility to ensure that the model output is correctly formatted.

If the model output is a ModelOutput or a dict, the head looks up HeadConfig.output_name in the model output. You can set output_name in the HeadConfig to make sure the head locates the required tensor.
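As a rough illustration of the lookup rules above, the sketch below mimics how a head might pick its input. `select_output` and `TUPLE_INDEX` are hypothetical names invented for this example, the tuple positions follow the indices used in the `forward` methods listed on this page, and plain strings stand in for tensors. The actual heads implement this logic inside their own `forward` methods.

```python
from collections.abc import Mapping

# Hypothetical mapping from output name to tuple position; it mirrors the
# indices used by the forward methods on this page, not a library API.
TUPLE_INDEX = {"last_hidden_state": 0, "pooler_output": 1, "attentions": -1}


def select_output(outputs, output_name):
    """Pick the entry a head needs from a dict-like or tuple model output."""
    if isinstance(outputs, Mapping):
        return outputs[output_name]
    if isinstance(outputs, tuple):
        return outputs[TUPLE_INDEX[output_name]]
    raise ValueError(f"Unsupported type for outputs: {type(outputs)}")


# Strings stand in for tensors to keep the sketch dependency-free.
as_dict = {"last_hidden_state": "H", "pooler_output": "P", "attentions": "A"}
as_tuple = ("H", "P", "A")

print(select_output(as_dict, "pooler_output"))
print(select_output(as_tuple, "pooler_output"))
print(select_output(as_tuple, "attentions"))
```

With a dict-like output the name is used directly, so a custom output_name only matters in that case. With a tuple, the position alone decides which tensor is used.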

multimolecule.module.heads.config

HeadConfig

Bases: BaseHeadConfig

Configuration class for a prediction head.

Parameters:

Name Type Description Default

num_labels

Optional[int]

Number of labels to use in the last layer added to the model, typically for a classification task.

The head should fall back to Config.num_labels if this is None.

None

problem_type

Optional[str]

Problem type for XxxForYyyPrediction models. Can be one of "binary", "regression", "multiclass" or "multilabel".

The head should fall back to Config.problem_type if this is None.

None

hidden_size

Optional[int]

Dimensionality of the encoder layers and the pooler layer.

The head should fall back to Config.hidden_size if this is None.

None

dropout

float

The dropout ratio for the hidden states.

0.0

transform

Optional[str]

The transform operation applied to hidden states.

None

transform_act

Optional[str]

The activation function of the transform applied to hidden states.

"gelu"

bias

bool

Whether to apply bias to the final prediction layer.

True

act

Optional[str]

The activation function of the final prediction output.

None

layer_norm_eps

float

The epsilon used by the layer normalization layers.

1e-12

output_name

Optional[str]

The name of the tensor required in model outputs.

If None, the default output name of the corresponding head is used.

None

type

Optional[str]

The type of the head in the model.

This is used by MultiMoleculeModel to construct heads.

None
Source code in multimolecule/module/heads/config.py
Python
class HeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a prediction head.

    Args:
        num_labels:
            Number of labels to use in the last layer added to the model, typically for a classification task.

            The head should fall back to [`Config.num_labels`][multimolecule.PreTrainedConfig] if this is `None`.
        problem_type:
            Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`,
            `"multiclass"` or `"multilabel"`.

            The head should fall back to [`Config.problem_type`][multimolecule.PreTrainedConfig] if this is `None`.
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            The head should fall back to [`Config.hidden_size`][multimolecule.PreTrainedConfig] if this is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name:
            The name of the tensor required in model outputs.

            If `None`, the default output name of the corresponding head is used.
        type:
            The type of the head in the model.

            This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads.
    """

    num_labels: Optional[int] = None
    problem_type: Optional[str] = None
    hidden_size: Optional[int] = None
    dropout: float = 0.0
    transform: Optional[str] = None
    transform_act: Optional[str] = "gelu"
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
    output_name: Optional[str] = None
    type: Optional[str] = None
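The fallback described for num_labels, problem_type and hidden_size can be pictured with a small sketch. `ModelConfigSketch`, `HeadConfigSketch` and `resolve` are hypothetical stand-ins written for this example only. How the library actually merges the two configurations is not shown on this page, so treat this as one plausible reading of "fall back to the model config when the field is None".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for the model-level configuration."""

    num_labels: int = 1
    problem_type: Optional[str] = None
    hidden_size: int = 768


@dataclass
class HeadConfigSketch:
    """Hypothetical stand-in mirroring three HeadConfig fields."""

    num_labels: Optional[int] = None
    problem_type: Optional[str] = None
    hidden_size: Optional[int] = None


def resolve(head: HeadConfigSketch, model: ModelConfigSketch) -> HeadConfigSketch:
    """Fill every head field left as None with the model-level value."""

    def pick(head_value, model_value):
        return model_value if head_value is None else head_value

    return HeadConfigSketch(
        num_labels=pick(head.num_labels, model.num_labels),
        problem_type=pick(head.problem_type, model.problem_type),
        hidden_size=pick(head.hidden_size, model.hidden_size),
    )


model_config = ModelConfigSketch(num_labels=2, problem_type="binary", hidden_size=640)
head_config = HeadConfigSketch(num_labels=3)  # only num_labels is overridden
resolved = resolve(head_config, model_config)
print(resolved)
```

Here the head keeps its own num_labels while problem_type and hidden_size are taken from the model configuration.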

MaskedLMHeadConfig

Bases: BaseHeadConfig

Configuration class for a Masked Language Modeling head.

Parameters:

Name Type Description Default

hidden_size

Optional[int]

Dimensionality of the encoder layers and the pooler layer.

The head should fall back to Config.hidden_size if this is None.

None

dropout

float

The dropout ratio for the hidden states.

0.0

transform

Optional[str]

The transform operation applied to hidden states.

"nonlinear"

transform_act

Optional[str]

The activation function of the transform applied to hidden states.

"gelu"

bias

bool

Whether to apply bias to the final prediction layer.

True

act

Optional[str]

The activation function of the final prediction output.

None

layer_norm_eps

float

The epsilon used by the layer normalization layers.

1e-12

output_name

Optional[str]

The name of the tensor required in model outputs.

If None, the default output name of the corresponding head is used.

None
Source code in multimolecule/module/heads/config.py
Python
class MaskedLMHeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a Masked Language Modeling head.

    Args:
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            The head should fall back to [`Config.hidden_size`][multimolecule.PreTrainedConfig] if this is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name:
            The name of the tensor required in model outputs.

            If `None`, the default output name of the corresponding head is used.
    """

    hidden_size: Optional[int] = None
    dropout: float = 0.0
    transform: Optional[str] = "nonlinear"
    transform_act: Optional[str] = "gelu"
    bias: bool = True
    act: Optional[str] = None
    layer_norm_eps: float = 1e-12
    output_name: Optional[str] = None

multimolecule.module.heads.sequence

SequencePredictionHead

Bases: PredictionHead

Head for sequence-level tasks.

Parameters:

Name Type Description Default

config

PreTrainedConfig

The configuration object for the model.

required

head_config

HeadConfig | None

The configuration object for the head. If None, will use configuration from the config.

None
Source code in multimolecule/module/heads/sequence.py
Python
@HeadRegistry.register("sequence")
class SequencePredictionHead(PredictionHead):
    r"""
    Head for sequence-level tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "pooler_output"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the SequencePredictionHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
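        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            # In a tuple output, the pooler output is the second element.
            output = outputs[1]
        else:
            # Fail early instead of hitting an unbound `output` below.
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        return super().forward(output, labels, **kwargs)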

output_name class-attribute instance-attribute

Python
output_name: str = 'pooler_output'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the SequencePredictionHead.

Parameters:

Name Type Description Default
outputs
ModelOutput | Tuple[Tensor, ...]

The outputs of the model.

required
labels
Tensor | None

The labels for the head.

None
output_name
str | None

The name of the output to use. Defaults to self.output_name.

None
Source code in multimolecule/module/heads/sequence.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Tuple[Tensor, ...],
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the SequencePredictionHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
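    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        # In a tuple output, the pooler output is the second element.
        output = outputs[1]
    else:
        # Fail early instead of hitting an unbound `output` below.
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    return super().forward(output, labels, **kwargs)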

multimolecule.module.heads.token

TokenPredictionHead

Bases: PredictionHead

Head for token-level tasks.

Parameters:

Name Type Description Default

config

PreTrainedConfig

The configuration object for the model.

required

head_config

HeadConfig | None

The configuration object for the head. If None, will use configuration from the config.

None
Source code in multimolecule/module/heads/token.py
Python
@HeadRegistry.token.register("single", default=True)
@TokenHeadRegistryHF.register("single", default=True)
class TokenPredictionHead(PredictionHead):
    r"""
    Head for token-level tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenPredictionHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self._get_attention_mask(input_ids)
        output = output * attention_mask.unsqueeze(-1)
        output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids)

        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenPredictionHead.

Parameters:

Name Type Description Default
outputs
ModelOutput | Tuple[Tensor, ...]

The outputs of the model.

required
attention_mask
Tensor | None

The attention mask for the inputs.

None
input_ids
NestedTensor | Tensor | None

The input ids for the inputs.

None
labels
Tensor | None

The labels for the head.

None
output_name
str | None

The name of the output to use. Defaults to self.output_name.

None
Source code in multimolecule/module/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self._get_attention_mask(input_ids)
    output = output * attention_mask.unsqueeze(-1)
    output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids)

    return super().forward(output, labels, **kwargs)
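The masking step `output * attention_mask.unsqueeze(-1)` zeroes the feature vector of every padded position before special tokens are removed. The sketch below reproduces only that multiplication with nested lists in place of tensors. `mask_hidden_states` is a hypothetical helper written for this example, and the special-token removal done by `_remove_special_tokens` is deliberately left out because its implementation is not shown here.

```python
def mask_hidden_states(hidden, attention_mask):
    """Zero the feature vector of every position whose mask value is 0.

    hidden: nested list of shape (batch, length, hidden_size)
    attention_mask: nested list of shape (batch, length) with 0/1 entries
    """
    return [
        [[value * m for value in vector] for vector, m in zip(sequence, mask)]
        for sequence, mask in zip(hidden, attention_mask)
    ]


# One sequence of three positions; the last position is padding.
hidden = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
attention_mask = [[1, 1, 0]]

masked = mask_hidden_states(hidden, attention_mask)
print(masked)  # [[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]]
```

Each mask entry is broadcast over the hidden dimension, which is what `unsqueeze(-1)` achieves in the tensor version.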

TokenKMerHead

Bases: PredictionHead

Head for token-level tasks with k-mer inputs.

Parameters:

Name Type Description Default

config

PreTrainedConfig

The configuration object for the model.

required

head_config

HeadConfig | None

The configuration object for the head. If None, will use configuration from the config.

None
Source code in multimolecule/module/heads/token.py
Python
@HeadRegistry.register("token.kmer")
@TokenHeadRegistryHF.register("kmer")
class TokenKMerHead(PredictionHead):
    r"""
    Head for token-level tasks with k-mer inputs.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.nmers = config.nmers
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name
        # Do not pass bos_token_id and eos_token_id to unfold_kmer_embeddings
        # As they will be removed in preprocess
        self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenKMerHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self._get_attention_mask(input_ids)
        output = output * attention_mask.unsqueeze(-1)
        output, attention_mask, _ = self._remove_special_tokens(output, attention_mask, input_ids)

        output = self.unfold_kmer_embeddings(output, attention_mask)
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenKMerHead.

Parameters:

Name Type Description Default
outputs
ModelOutput | Tuple[Tensor, ...]

The outputs of the model.

required
attention_mask
Tensor | None

The attention mask for the inputs.

None
input_ids
NestedTensor | Tensor | None

The input ids for the inputs.

None
labels
Tensor | None

The labels for the head.

None
output_name
str | None

The name of the output to use. Defaults to self.output_name.

None
Source code in multimolecule/module/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenKMerHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self._get_attention_mask(input_ids)
    output = output * attention_mask.unsqueeze(-1)
    output, attention_mask, _ = self._remove_special_tokens(output, attention_mask, input_ids)

    output = self.unfold_kmer_embeddings(output, attention_mask)
    return super().forward(output, labels, **kwargs)

multimolecule.module.heads.contact

ContactPredictionHead

Bases: PredictionHead

Head for contact-level tasks.

Performs symmetrization and average product correction.

Parameters:

Name Type Description Default

config

PreTrainedConfig

The configuration object for the model.

required

head_config

HeadConfig | None

The configuration object for the head. If None, will use configuration from the config.

None
Source code in multimolecule/module/heads/contact.py
Python
@HeadRegistry.contact.register("attention")
class ContactPredictionHead(PredictionHead):
    r"""
    Head for contact-level tasks.

    Performs symmetrization and average product correction.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "attentions"
    r"""The default output to use for the head."""

    requires_attention: bool = True

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.config.hidden_size = config.num_hidden_layers * config.num_attention_heads
        num_layers = self.config.get("num_layers", 16)
        num_channels = self.config.get("num_channels", self.config.hidden_size // 10)  # type: ignore[operator]
        block = self.config.get("block", "auto")
        self.decoder = ResNet(
            num_layers=num_layers,
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=block,
            num_channels=num_channels,
            num_labels=self.num_labels,
        )
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the ContactPredictionHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """

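        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            # In a tuple output, the attentions are the last element.
            output = outputs[-1]
        else:
            # Fail early instead of hitting an unbound `output` below.
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        attentions = torch.stack(output, 1)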

        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        if attention_mask is None:
            attention_mask = self._get_attention_mask(input_ids)
        attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
        attentions = attentions * attention_mask[:, None, None, :, :]

        # remove cls token attentions
        if self.bos_token_id is not None:
            attentions = attentions[..., 1:, 1:]
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        # remove eos token attentions
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(attentions)
            else:
                last_valid_indices = attention_mask.sum(dim=-1)
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]

        # features: batch x channels x input_ids x input_ids (symmetric)
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
        attentions = attentions.to(self.decoder.proj.weight.device)
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1).squeeze(3)

        return super().forward(attentions, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'attentions'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the ContactPredictionHead.

Parameters:

Name Type Description Default
outputs
ModelOutput | Mapping | Tuple[Tensor, ...]

The outputs of the model.

required
attention_mask
Tensor | None

The attention mask for the inputs.

None
input_ids
NestedTensor | Tensor | None

The input ids for the inputs.

None
labels
Tensor | None

The labels for the head.

None
output_name
str | None

The name of the output to use. Defaults to self.output_name.

None
Source code in multimolecule/module/heads/contact.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the ContactPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """

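    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        # In a tuple output, the attentions are the last element.
        output = outputs[-1]
    else:
        # Fail early instead of hitting an unbound `output` below.
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    attentions = torch.stack(output, 1)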

    # In the original model, attentions for padding tokens are completely zeroed out.
    # This makes no difference most of the time because the other tokens won't attend to them,
    # but it does for the contact prediction task, which takes attentions as input,
    # so we have to mimic that here.
    if attention_mask is None:
        attention_mask = self._get_attention_mask(input_ids)
    attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    attentions = attentions * attention_mask[:, None, None, :, :]

    # remove cls token attentions
    if self.bos_token_id is not None:
        attentions = attentions[..., 1:, 1:]
        attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    # remove eos token attentions
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(attentions)
        else:
            last_valid_indices = attention_mask.sum(dim=-1)
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]

    # features: batch x channels x input_ids x input_ids (symmetric)
    batch_size, layers, heads, seqlen, _ = attentions.size()
    attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
    attentions = attentions.to(self.decoder.proj.weight.device)
    attentions = average_product_correct(symmetrize(attentions))
    attentions = attentions.permute(0, 2, 3, 1).squeeze(3)

    return super().forward(attentions, labels, **kwargs)
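The padding step in this forward pass turns the per-token attention mask into a pairwise mask, so that an attention entry survives only if both positions are real tokens. The sketch below shows that idea for a single sequence and a single attention map, using nested lists instead of tensors. `pairwise_mask` and `apply_pairwise_mask` are hypothetical helpers written for this example.

```python
def pairwise_mask(attention_mask):
    """Return pair[i][j] = mask[i] * mask[j] for one sequence."""
    return [[mi * mj for mj in attention_mask] for mi in attention_mask]


def apply_pairwise_mask(attention, pair):
    """Zero every attention entry that involves a padded position."""
    return [
        [a * p for a, p in zip(attention_row, pair_row)]
        for attention_row, pair_row in zip(attention, pair)
    ]


mask = [1, 1, 0]  # the third position is padding
pair = pairwise_mask(mask)
print(pair)  # [[1, 1, 0], [1, 1, 0], [0, 0, 0]]

attention = [[0.5, 0.3, 0.2], [0.1, 0.8, 0.1], [0.4, 0.4, 0.2]]
print(apply_pairwise_mask(attention, pair))
# [[0.5, 0.3, 0.0], [0.1, 0.8, 0.0], [0.0, 0.0, 0.0]]
```

In the tensor version the same pairwise mask is broadcast over the layer and head dimensions via `attention_mask[:, None, None, :, :]`.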

ContactLogitsHead

Bases: PredictionHead

Head for contact-level tasks.

Builds a symmetric pairwise feature map by multiplying the token representations of every pair of positions.

Parameters:

Name Type Description Default

config

PreTrainedConfig

The configuration object for the model.

required

head_config

HeadConfig | None

The configuration object for the head. If None, will use configuration from the config.

None
Source code in multimolecule/module/heads/contact.py
Python
@HeadRegistry.contact.register("logits")
class ContactLogitsHead(PredictionHead):
    r"""
    Head for contact-level tasks.

    Builds a symmetric pairwise feature map by multiplying the token representations of every pair of positions.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    requires_attention: bool = False

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        num_layers = self.config.get("num_layers", 16)
        num_channels = self.config.get("num_channels", self.config.hidden_size // 10)  # type: ignore[operator]
        block = self.config.get("block", "auto")
        self.decoder = ResNet(
            num_layers=num_layers,
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=block,
            num_channels=num_channels,
            num_labels=self.num_labels,
        )
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the ContactLogitsHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self._get_attention_mask(input_ids)
        output = output * attention_mask.unsqueeze(-1)
        output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids)

        # make symmetric contact map
        contact_map = output.unsqueeze(1) * output.unsqueeze(2)

        return super().forward(contact_map, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the ContactLogitsHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Mapping \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `attention_mask` | `Tensor \| None` | The attention mask for the inputs. | `None` |
| `input_ids` | `NestedTensor \| Tensor \| None` | The input ids for the inputs. | `None` |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |

Source code in multimolecule/module/heads/contact.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the ContactPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

    if attention_mask is None:
        attention_mask = self._get_attention_mask(input_ids)
    output = output * attention_mask.unsqueeze(-1)
    output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids)

    # make symmetric contact map
    contact_map = output.unsqueeze(1) * output.unsqueeze(2)

    return super().forward(contact_map, labels, **kwargs)
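
The broadcasted product that builds the contact map may be easier to follow on a toy tensor. The sketch below assumes PyTorch is installed and uses arbitrary values; it only reproduces the `unsqueeze` step, not the masking or the rest of the head.

```python
import torch

# Toy token embeddings with shape (batch=1, seq_len=3, hidden=2).
output = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])

# Same broadcast as in `forward`: (1, 1, 3, 2) * (1, 3, 1, 2) -> (1, 3, 3, 2).
contact_map = output.unsqueeze(1) * output.unsqueeze(2)

print(contact_map.shape)  # torch.Size([1, 3, 3, 2])

# Entry (i, j) holds the element-wise product of the embeddings of tokens i and j ...
assert torch.equal(contact_map[0, 0, 1], output[0, 0] * output[0, 1])
# ... so swapping i and j leaves the map unchanged.
assert torch.equal(contact_map, contact_map.transpose(1, 2))
```

Because multiplication is commutative, the pairwise map is symmetric by construction and needs no separate symmetrization step.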

symmetrize

Python
symmetrize(x)

Make layer symmetric in final two dimensions, used for contact prediction.

Source code in multimolecule/module/heads/contact.py
Python
def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)
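
A small usage example, assuming PyTorch is installed; the function body is copied from the listing above and the input values are arbitrary.

```python
import torch


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
s = symmetrize(x)
print(s)
# tensor([[2., 5.],
#         [5., 8.]])

# The result equals its own transpose.
assert torch.equal(s, s.transpose(-1, -2))
```

Note that the function adds the transpose without halving, so diagonal entries are doubled.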

average_product_correct

Python
average_product_correct(x)

Perform average product correction (APC), used for contact prediction.

Source code in multimolecule/module/heads/contact.py
Python
def average_product_correct(x):
    "Perform average product correction (APC), used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized
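
Read element-wise, the function computes `x[i, j] - row_sum[i] * col_sum[j] / total`. A consequence is that every row and every column of the result sums to zero, which the sketch below checks on an arbitrary 2x2 input. It assumes PyTorch is installed; the function body follows the listing above.

```python
import torch


def average_product_correct(x):
    a1 = x.sum(-1, keepdims=True)         # row sums
    a2 = x.sum(-2, keepdims=True)         # column sums
    a12 = x.sum((-1, -2), keepdims=True)  # total sum

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    return x - avg


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
corrected = average_product_correct(x)

# Row and column sums of the corrected matrix vanish (up to rounding).
assert torch.allclose(corrected.sum(-1), torch.zeros(2), atol=1e-6)
assert torch.allclose(corrected.sum(-2), torch.zeros(2), atol=1e-6)
```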

multimolecule.module.heads.pretrain

MaskedLMHead

Bases: Module

Head for masked language modeling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `MaskedLMHeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |

Source code in multimolecule/module/heads/pretrain.py
Python
@HeadRegistry.register("masked_lm")
class MaskedLMHead(nn.Module):
    r"""
    Head for masked language modeling.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(
        self, config: PreTrainedConfig, weight: Tensor | None = None, head_config: MaskedLMHeadConfig | None = None
    ):
        super().__init__()
        if head_config is None:
            head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig()
        self.config: MaskedLMHeadConfig = head_config
        if self.config.hidden_size is None:
            self.config.hidden_size = config.hidden_size
        self.num_labels = config.vocab_size
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HeadTransformRegistryHF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        if weight is not None:
            self.decoder.weight = weight
        if self.config.bias:
            self.bias = nn.Parameter(torch.zeros(self.num_labels))
            self.decoder.bias = self.bias
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(
        self, outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None
    ) -> HeadOutput:
        r"""
        Forward pass of the MaskedLMHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        output = self.dropout(output)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
            return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
        return HeadOutput(output)
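
The constructor also accepts a `weight` tensor that replaces the decoder's weight. The parameter is not described in the docstring; one common use, assumed here, is tying the decoder to an input embedding matrix. The sketch below assumes PyTorch is installed and uses arbitrary toy sizes; `embedding` is a hypothetical stand-in for a model's token embeddings.

```python
import torch
from torch import nn

vocab_size, hidden_size = 8, 4  # arbitrary toy sizes

embedding = nn.Embedding(vocab_size, hidden_size)         # weight: (vocab, hidden)
decoder = nn.Linear(hidden_size, vocab_size, bias=False)  # weight: (vocab, hidden)

# Same assignment as in `__init__`: the decoder reuses the given weight.
decoder.weight = embedding.weight

# A separate bias parameter is attached afterwards, as when `config.bias` is true.
bias = nn.Parameter(torch.zeros(vocab_size))
decoder.bias = bias

# Both modules now share one parameter object, so updates affect both.
assert decoder.weight is embedding.weight

hidden = torch.randn(2, 5, hidden_size)
logits = decoder(hidden)
print(logits.shape)  # torch.Size([2, 5, 8])
```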

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None) -> HeadOutput

Forward pass of the MaskedLMHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |

Source code in multimolecule/module/heads/pretrain.py
Python
def forward(
    self, outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None
) -> HeadOutput:
    r"""
    Forward pass of the MaskedLMHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    output = self.dropout(output)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)
    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
        return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
    return HeadOutput(output)
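
For dense tensors, the loss is a cross-entropy over logits flattened to `(batch * seq_len, vocab)` and labels flattened to `(batch * seq_len,)`. The sketch below reproduces that call on random data, assuming PyTorch is installed. Marking a position with `-100` relies on the default `ignore_index` of `F.cross_entropy`; that convention comes from PyTorch and is not stated on this page.

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 3, 5  # arbitrary toy sizes
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[0, 0] = -100  # skipped by the default ignore_index of F.cross_entropy

# Same flattening as in `forward`.
loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))

# Cross-check: the mean of the per-token losses over the positions that are kept.
per_token = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction="none")
keep = labels.view(-1) != -100
assert torch.allclose(loss, per_token[keep].mean())
```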

multimolecule.module.heads.generic

PredictionHead

Bases: Module

Head for tasks at all levels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `HeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |

Source code in multimolecule/module/heads/generic.py
Python
class PredictionHead(nn.Module):
    r"""
    Head for tasks at all levels.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    num_labels: int
    requires_attention: bool = False

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__()
        if head_config is None:
            head_config = config.head or HeadConfig(num_labels=config.num_labels)
        elif head_config.num_labels is None:
            head_config.num_labels = config.num_labels
        self.config = head_config
        if self.config.hidden_size is None:
            self.config.hidden_size = config.hidden_size
        if self.config.problem_type is None:
            self.config.problem_type = config.problem_type
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
        self.num_labels = self.config.num_labels  # type: ignore[assignment]
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HeadTransformRegistryHF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = CriterionRegistry.build(self.config)

    def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
        r"""
        Forward pass of the PredictionHead.

        Args:
            embeddings: The embeddings to be passed through the head.
            labels: The labels for the head.
        """
        if kwargs:
            warn(
                f"The following arguments are not applicable to {self.__class__.__name__} "
                f"and will be ignored: {kwargs.keys()}"
            )
        output = self.dropout(embeddings)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, self.criterion(output.concat, labels.concat))
            return HeadOutput(output, self.criterion(output, labels))
        return HeadOutput(output)

    def _get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor:
        if isinstance(input_ids, NestedTensor):
            return input_ids.mask
        if input_ids is None:
            raise ValueError(
                f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
            )
        if self.pad_token_id is None:
            raise ValueError(
                f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
            )
        return input_ids.ne(self.pad_token_id)

    def _remove_special_tokens(
        self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        # remove cls token embeddings
        if self.bos_token_id is not None:
            output = output[..., 1:, :]
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        # remove eos token embeddings
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output)
                input_ids = input_ids[..., :-1]
            else:
                last_valid_indices = attention_mask.sum(dim=-1)
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
            output = output * eos_mask[:, :, None]
            output = output[..., :-1, :]
            attention_mask = attention_mask[..., 1:]
        return output, attention_mask, input_ids
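
The two helpers can be traced on a toy batch. The sketch below assumes PyTorch is installed and follows the branch where `input_ids` is available; the token ids are hypothetical values chosen for illustration, and the initial multiplication by the mask mirrors what `ContactPredictionHead.forward` does before calling `_remove_special_tokens`.

```python
import torch

# Hypothetical token ids, chosen only for this sketch.
pad_token_id, bos_token_id, eos_token_id = 0, 1, 2

input_ids = torch.tensor([
    [1, 5, 6, 7, 2],  # <bos> a b c <eos>
    [1, 5, 6, 2, 0],  # <bos> a b <eos> <pad>
])

# As in `_get_attention_mask`: every non-padding position is valid.
attention_mask = input_ids.ne(pad_token_id)

# Dummy hidden states of ones, with padded positions zeroed out.
hidden = torch.ones(2, 5, 3) * attention_mask.unsqueeze(-1)

# Remove the leading <bos> position.
hidden = hidden[..., 1:, :]
ids = input_ids[..., 1:]

# Zero the <eos> position wherever it sits, then drop the last position.
eos_mask = ids.ne(eos_token_id).to(hidden)
hidden = hidden * eos_mask[:, :, None]
hidden = hidden[..., :-1, :]

print(hidden.shape)  # torch.Size([2, 3, 3])
print(hidden[..., 0])
# tensor([[1., 1., 1.],
#         [1., 1., 0.]])
```

In the second sequence the `<eos>` token is not in the last position, which is why it is zeroed by the mask rather than removed by slicing alone.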

forward

Python
forward(embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput

Forward pass of the PredictionHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embeddings` | `Tensor` | The embeddings to be passed through the head. | required |
| `labels` | `Tensor \| None` | The labels for the head. | required |

Source code in multimolecule/module/heads/generic.py
Python
def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
    r"""
    Forward pass of the PredictionHead.

    Args:
        embeddings: The embeddings to be passed through the head.
        labels: The labels for the head.
    """
    if kwargs:
        warn(
        f"The following arguments are not applicable to {self.__class__.__name__} "
            f"and will be ignored: {kwargs.keys()}"
        )
    output = self.dropout(embeddings)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)
    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, self.criterion(output.concat, labels.concat))
        return HeadOutput(output, self.criterion(output, labels))
    return HeadOutput(output)
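
The order of operations in `forward` (dropout, transform, decoder, optional activation, optional loss) can be sketched with a stripped-down module. `TinyHead` is a hypothetical stand-in, not the library's class: it assumes PyTorch is installed, uses `nn.Identity` in place of the configured transform and `nn.MSELoss` in place of the criterion built from the head configuration, and returns a plain tuple instead of a `HeadOutput`.

```python
import torch
from torch import nn


class TinyHead(nn.Module):
    """Hypothetical stand-in that mirrors the step order of PredictionHead.forward."""

    def __init__(self, hidden_size, num_labels, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.transform = nn.Identity()  # stand-in for the configured transform
        self.decoder = nn.Linear(hidden_size, num_labels)
        self.activation = None          # no final activation in this sketch
        self.criterion = nn.MSELoss()   # stand-in criterion (regression)

    def forward(self, embeddings, labels=None):
        output = self.dropout(embeddings)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        # The loss is only computed when labels are given.
        loss = self.criterion(output, labels) if labels is not None else None
        return output, loss


head = TinyHead(hidden_size=4, num_labels=1).eval()
embeddings = torch.randn(2, 4)

logits, loss = head(embeddings, torch.zeros(2, 1))
print(logits.shape)  # torch.Size([2, 1])

logits_only, no_loss = head(embeddings)
assert no_loss is None
```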

multimolecule.module.heads.output

HeadOutput dataclass

Bases: ModelOutput

Output of a prediction head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logits` | `FloatTensor` | The prediction logits from the head. | required |
| `loss` | `FloatTensor \| None` | The loss from the head. Defaults to `None`. | `None` |

Source code in multimolecule/module/heads/output.py
Python
@dataclass
class HeadOutput(ModelOutput):
    r"""
    Output of a prediction head.

    Args:
        logits: The prediction logits from the head.
        loss: The loss from the head.
            Defaults to None.
    """

    logits: FloatTensor
    loss: FloatTensor | None = None
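
The field layout explains why the heads above can write `HeadOutput(output)` for inference and `HeadOutput(output, loss)` for training. The sketch below shows the same pattern with a plain dataclass; `SketchOutput` is a hypothetical stand-in that does not subclass `ModelOutput`, and lists and floats stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchOutput:
    """Hypothetical stand-in; the real HeadOutput subclasses ModelOutput."""

    logits: Any
    loss: Optional[Any] = None


# Inference: only the logits are supplied, so `loss` keeps its default.
pred_only = SketchOutput([0.1, 0.9])
# Training: logits and loss are passed positionally, in field order.
with_loss = SketchOutput([0.1, 0.9], 0.25)

print(pred_only.loss)  # None
print(with_loss.loss)  # 0.25
```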