heads

`heads` provides a collection of pre-defined prediction heads.

Heads accept a ModelOutput, a dict, or a tuple as input. They automatically locate the model output required for prediction and process it accordingly.

Some prediction heads, such as ContactPredictionHead, require additional information, for example the attention_mask or the input_ids. These can be passed as positional or keyword arguments.

Note that heads follow the same ModelOutput conventions as 🤗 Transformers. If the model output is a tuple, the first element is treated as the last_hidden_state, the second element as the pooler_output, and the last element as the attentions. It is the user’s responsibility to ensure that the model output is correctly formatted.

If the model output is a ModelOutput or a dict, heads look up HeadConfig.output_name in the model output. You can specify output_name in the HeadConfig to ensure that the head locates the required tensor correctly.
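
For example, a prediction head can be called directly on a dict (or ModelOutput) produced by a backbone model. The following is a minimal sketch, not a complete pipeline: it assumes RnaFmConfig is available from multimolecule as a stand-in for any PreTrainedConfig, and imports the head from the module path shown in the headings below.

Python
import torch
from multimolecule import RnaFmConfig  # assumption: any PreTrainedConfig subclass works the same way
from multimolecule.module.heads.sequence import SequencePredictionHead

config = RnaFmConfig(num_labels=2)
head = SequencePredictionHead(config)

# A dict (or ModelOutput) is searched by name; a tuple would be indexed by position.
outputs = {"pooler_output": torch.randn(4, config.hidden_size)}
logits = head(outputs).logits  # shape: (4, 2)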

multimolecule.module.heads.config

HeadConfig dataclass

Bases: BaseHeadConfig

Configuration class for a prediction head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_labels` | `int` | Number of labels to use in the last layer added to the model, typically for a classification task. Head should look for `Config.num_labels` if it is `None`. | `None` |
| `problem_type` | `str` | Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`, `"single_label_classification"` or `"multi_label_classification"`. Head should look for `Config.problem_type` if it is `None`. | `None` |
| `hidden_size` | `int \| None` | Dimensionality of the encoder layers and the pooler layer. Head should look for `Config.hidden_size` if it is `None`. | `None` |
| `dropout` | `float` | The dropout ratio for the hidden states. | `0.0` |
| `transform` | `str \| None` | The transform operation applied to hidden states. | `None` |
| `transform_act` | `str \| None` | The activation function of the transform applied to hidden states. | `'gelu'` |
| `bias` | `bool` | Whether to apply bias to the final prediction layer. | `True` |
| `act` | `str \| None` | The activation function of the final prediction output. | `None` |
| `layer_norm_eps` | `float` | The epsilon used by the layer normalization layers. | `1e-12` |
| `output_name` | `str \| None` | The name of the tensor required in model outputs. If it is `None`, will use the default output name of the corresponding head. | `None` |
Source code in multimolecule/module/heads/config.py
Python
@dataclass
class HeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a prediction head.

    Args:
        num_labels:
            Number of labels to use in the last layer added to the model, typically for a classification task.

            Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`.
        problem_type:
            Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`,
            `"single_label_classification"` or `"multi_label_classification"`.

            Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`.
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            Head should look for [`Config.hidden_size`][multimolecule.PreTrainedConfig] if is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name (`str`, *optional*):
            The name of the tensor required in model outputs.

            If is `None`, will use the default output name of the corresponding head.
    """

    num_labels: int = None  # type: ignore[assignment]
    problem_type: str = None  # type: ignore[assignment]
    hidden_size: int | None = None
    dropout: float = 0.0
    transform: str | None = None
    transform_act: str | None = "gelu"
    bias: bool = True
    act: str | None = None
    layer_norm_eps: float = 1e-12
    output_name: str | None = None
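
A minimal construction sketch; the field values below are illustrative, not recommended defaults:

Python
from multimolecule.module.heads.config import HeadConfig

# A binary classification head that reads "last_hidden_state" from the model
# outputs instead of the default output of the corresponding head.
head_config = HeadConfig(
    num_labels=2,
    problem_type="single_label_classification",
    dropout=0.1,
    transform="nonlinear",
    output_name="last_hidden_state",
)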

MaskedLMHeadConfig dataclass

Bases: BaseHeadConfig

Configuration class for a Masked Language Modeling head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hidden_size` | `int \| None` | Dimensionality of the encoder layers and the pooler layer. Head should look for `Config.hidden_size` if it is `None`. | `None` |
| `dropout` | `float` | The dropout ratio for the hidden states. | `0.0` |
| `transform` | `str \| None` | The transform operation applied to hidden states. | `'nonlinear'` |
| `transform_act` | `str \| None` | The activation function of the transform applied to hidden states. | `'gelu'` |
| `bias` | `bool` | Whether to apply bias to the final prediction layer. | `True` |
| `act` | `str \| None` | The activation function of the final prediction output. | `None` |
| `layer_norm_eps` | `float` | The epsilon used by the layer normalization layers. | `1e-12` |
| `output_name` | `str \| None` | The name of the tensor required in model outputs. If it is `None`, will use the default output name of the corresponding head. | `None` |
Source code in multimolecule/module/heads/config.py
Python
@dataclass
class MaskedLMHeadConfig(BaseHeadConfig):
    r"""
    Configuration class for a Masked Language Modeling head.

    Args:
        hidden_size:
            Dimensionality of the encoder layers and the pooler layer.

            Head should look for [`Config.hidden_size`][multimolecule.PreTrainedConfig] if is `None`.
        dropout:
            The dropout ratio for the hidden states.
        transform:
            The transform operation applied to hidden states.
        transform_act:
            The activation function of transform applied to hidden states.
        bias:
            Whether to apply bias to the final prediction layer.
        act:
            The activation function of the final prediction output.
        layer_norm_eps:
            The epsilon used by the layer normalization layers.
        output_name (`str`, *optional*):
            The name of the tensor required in model outputs.

            If is `None`, will use the default output name of the corresponding head.
    """

    hidden_size: int | None = None
    dropout: float = 0.0
    transform: str | None = "nonlinear"
    transform_act: str | None = "gelu"
    bias: bool = True
    act: str | None = None
    layer_norm_eps: float = 1e-12
    output_name: str | None = None
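
As a brief sketch, MaskedLMHeadConfig is constructed the same way; only the field set differs:

Python
from multimolecule.module.heads.config import MaskedLMHeadConfig

# Keep the default nonlinear transform, add a little dropout, and relax the
# layer-norm epsilon; all values here are purely illustrative.
lm_head_config = MaskedLMHeadConfig(dropout=0.1, layer_norm_eps=1e-5)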

multimolecule.module.heads.sequence

SequencePredictionHead

Bases: PredictionHead

Head for sequence-level tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `HeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |
Source code in multimolecule/module/heads/sequence.py
Python
@HeadRegistry.register("sequence")
class SequencePredictionHead(PredictionHead):
    r"""
    Head for tasks in sequence-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "pooler_output"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the SequencePredictionHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[1]
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'pooler_output'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the SequencePredictionHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |
Source code in multimolecule/module/heads/sequence.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Tuple[Tensor, ...],
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the SequencePredictionHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[1]
    return super().forward(output, labels, **kwargs)
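
A minimal sketch of overriding output_name, again assuming RnaFmConfig as the backbone configuration; the "mean_pooled" key is a hypothetical name used only for illustration:

Python
import torch
from multimolecule import RnaFmConfig  # assumption: stands in for any PreTrainedConfig
from multimolecule.module.heads.config import HeadConfig
from multimolecule.module.heads.sequence import SequencePredictionHead

config = RnaFmConfig(num_labels=3)
# Read a hypothetical "mean_pooled" tensor instead of the default "pooler_output".
head = SequencePredictionHead(config, HeadConfig(num_labels=3, output_name="mean_pooled"))

outputs = {"mean_pooled": torch.randn(2, config.hidden_size)}
print(head(outputs).logits.shape)  # torch.Size([2, 3])

# With a tuple instead of a dict, the second element is treated as the pooler_output.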

multimolecule.module.heads.token

TokenPredictionHead

Bases: PredictionHead

Head for token-level tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `HeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |
Source code in multimolecule/module/heads/token.py
Python
@HeadRegistry.token.register("single", default=True)
@TokenHeadRegistryHF.register("single", default=True)
class TokenPredictionHead(PredictionHead):
    r"""
    Head for tasks in token-level.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenPredictionHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if attention_mask is None:
            if isinstance(input_ids, NestedTensor):
                input_ids, attention_mask = input_ids.tensor, input_ids.mask
            else:
                if input_ids is None:
                    raise ValueError(
                        f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
                    )
                if self.pad_token_id is None:
                    raise ValueError(
                        f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
                    )
                attention_mask = input_ids.ne(self.pad_token_id)

        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        output = output * attention_mask.unsqueeze(-1)

        # remove cls token embeddings
        if self.bos_token_id is not None:
            output = output[..., 1:, :]
            # process attention_mask and input_ids to make removal of eos token happy
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        # remove eos token embeddings
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output)
            else:
                last_valid_indices = attention_mask.sum(dim=-1)
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
            output = output * eos_mask[:, :, None]
            output = output[..., :-1, :]

        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenPredictionHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `attention_mask` | `Tensor \| None` | The attention mask for the inputs. | `None` |
| `input_ids` | `NestedTensor \| Tensor \| None` | The input ids for the inputs. | `None` |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |
Source code in multimolecule/module/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if attention_mask is None:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask
        else:
            if input_ids is None:
                raise ValueError(
                    f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
                )
            if self.pad_token_id is None:
                raise ValueError(
                    f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
                )
            attention_mask = input_ids.ne(self.pad_token_id)

    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    output = output * attention_mask.unsqueeze(-1)

    # remove cls token embeddings
    if self.bos_token_id is not None:
        output = output[..., 1:, :]
        # process attention_mask and input_ids to make removal of eos token happy
        attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    # remove eos token embeddings
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(output)
        else:
            last_valid_indices = attention_mask.sum(dim=-1)
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
        output = output * eos_mask[:, :, None]
        output = output[..., :-1, :]

    return super().forward(output, labels, **kwargs)
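
A minimal sketch of the token-level head deriving the attention mask from input_ids; it assumes RnaFmConfig and uses placeholder token ids (random ids may coincide with special tokens, which only zeroes a few values and does not change the output shape):

Python
import torch
from multimolecule import RnaFmConfig  # assumption: stands in for any PreTrainedConfig
from multimolecule.module.heads.token import TokenPredictionHead

config = RnaFmConfig(num_labels=2)
head = TokenPredictionHead(config)

batch_size, seq_len = 2, 8
outputs = {"last_hidden_state": torch.randn(batch_size, seq_len, config.hidden_size)}

# Build input_ids of the form <bos> ... <eos> so the head can derive the
# attention mask and strip the special-token embeddings.
token_ids = torch.randint(5, config.vocab_size, (batch_size, seq_len - 2))  # placeholder ids
input_ids = torch.cat(
    [
        torch.full((batch_size, 1), config.bos_token_id),
        token_ids,
        torch.full((batch_size, 1), config.eos_token_id),
    ],
    dim=-1,
)

logits = head(outputs, input_ids=input_ids).logits
print(logits.shape)  # torch.Size([2, 6, 2]) once <bos> and <eos> positions are removed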

TokenKMerHead

Bases: PredictionHead

Head for token-level tasks with k-mer inputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `HeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |
Source code in multimolecule/module/heads/token.py
Python
@HeadRegistry.register("token.kmer")
@TokenHeadRegistryHF.register("kmer")
class TokenKMerHead(PredictionHead):
    r"""
    Head for tasks in token-level with kmer inputs.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.nmers = config.nmers
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name
        # Do not pass bos_token_id and eos_token_id to unfold_kmer_embeddings
        # As they will be removed in preprocess
        self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers)

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the TokenKMerHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if attention_mask is None:
            if isinstance(input_ids, NestedTensor):
                input_ids, attention_mask = input_ids.tensor, input_ids.mask
            else:
                if input_ids is None:
                    raise ValueError(
                        f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
                    )
                if self.pad_token_id is None:
                    raise ValueError(
                        f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
                    )
                attention_mask = input_ids.ne(self.pad_token_id)

        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        output = output * attention_mask.unsqueeze(-1)

        # remove cls token embeddings
        if self.bos_token_id is not None:
            output = output[..., 1:, :]
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        # remove eos token embeddings
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output)
                input_ids = input_ids[..., :-1]
            else:
                last_valid_indices = attention_mask.sum(dim=-1)
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
            output = output * eos_mask[:, :, None]
            output = output[..., :-1, :]
            attention_mask = attention_mask[..., 1:]

        output = self.unfold_kmer_embeddings(output, attention_mask)
        return super().forward(output, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the TokenKMerHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `attention_mask` | `Tensor \| None` | The attention mask for the inputs. | `None` |
| `input_ids` | `NestedTensor \| Tensor \| None` | The input ids for the inputs. | `None` |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |
Source code in multimolecule/module/heads/token.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the TokenKMerHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if attention_mask is None:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask
        else:
            if input_ids is None:
                raise ValueError(
                    f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
                )
            if self.pad_token_id is None:
                raise ValueError(
                    f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
                )
            attention_mask = input_ids.ne(self.pad_token_id)

    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    output = output * attention_mask.unsqueeze(-1)

    # remove cls token embeddings
    if self.bos_token_id is not None:
        output = output[..., 1:, :]
        attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    # remove eos token embeddings
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(output)
            input_ids = input_ids[..., :-1]
        else:
            last_valid_indices = attention_mask.sum(dim=-1)
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
        output = output * eos_mask[:, :, None]
        output = output[..., :-1, :]
        attention_mask = attention_mask[..., 1:]

    output = self.unfold_kmer_embeddings(output, attention_mask)
    return super().forward(output, labels, **kwargs)
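
A minimal sketch of wiring up the k-mer head. It assumes a k-mer backbone configuration such as UtrBertConfig that exposes `nmers`; both the config name and its constructor arguments are assumptions here, and no output shape is asserted:

Python
import torch
from multimolecule import UtrBertConfig  # assumption: a k-mer model config exposing `nmers`
from multimolecule.module.heads.token import TokenKMerHead

config = UtrBertConfig(nmers=3, num_labels=2)  # assumed constructor signature
head = TokenKMerHead(config)

batch_size, num_kmers = 2, 10
outputs = {"last_hidden_state": torch.randn(batch_size, num_kmers, config.hidden_size)}
attention_mask = torch.ones(batch_size, num_kmers, dtype=torch.bool)

# After stripping special tokens, overlapping k-mer embeddings are unfolded back
# so that predictions are made per nucleotide rather than per k-mer token.
logits = head(outputs, attention_mask=attention_mask).logits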

multimolecule.module.heads.contact

ContactPredictionHead

Bases: PredictionHead

Head for contact-level tasks.

Performs symmetrization and average product correction (APC).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `HeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |
Source code in multimolecule/module/heads/contact.py
Python
@HeadRegistry.register("contact")
class ContactPredictionHead(PredictionHead):
    r"""
    Head for tasks in contact-level.

    Performs symmetrization, and average product correct.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "attentions"
    r"""The default output to use for the head."""

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id
        self.decoder = nn.Linear(
            config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias
        )
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the ContactPredictionHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if attention_mask is None:
            if isinstance(input_ids, NestedTensor):
                input_ids, attention_mask = input_ids.tensor, input_ids.mask
            else:
                if input_ids is None:
                    raise ValueError(
                        f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
                    )
                if self.pad_token_id is None:
                    raise ValueError(
                        f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
                    )
                attention_mask = input_ids.ne(self.pad_token_id)

        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[-1]
        attentions = torch.stack(output, 1)

        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
        attentions *= attention_mask[:, None, None, :, :]

        # remove cls token attentions
        if self.bos_token_id is not None:
            attentions = attentions[..., 1:, 1:]
            # process attention_mask and input_ids to make removal of eos token happy
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        # remove eos token attentions
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(attentions)
            else:
                last_valid_indices = attention_mask.sum(dim=-1)
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions *= eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]

        # features: batch x channels x input_ids x input_ids (symmetric)
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
        attentions = attentions.to(self.decoder.weight.device)
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1).squeeze(3)

        return super().forward(attentions, labels, **kwargs)

output_name class-attribute instance-attribute

Python
output_name: str = 'attentions'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Mapping | Tuple[Tensor, ...], attention_mask: Tensor | None = None, input_ids: NestedTensor | Tensor | None = None, labels: Tensor | None = None, output_name: str | None = None, **kwargs) -> HeadOutput

Forward pass of the ContactPredictionHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Mapping \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `attention_mask` | `Tensor \| None` | The attention mask for the inputs. | `None` |
| `input_ids` | `NestedTensor \| Tensor \| None` | The input ids for the inputs. | `None` |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |
Source code in multimolecule/module/heads/contact.py
Python
def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
    self,
    outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
    attention_mask: Tensor | None = None,
    input_ids: NestedTensor | Tensor | None = None,
    labels: Tensor | None = None,
    output_name: str | None = None,
    **kwargs,
) -> HeadOutput:
    r"""
    Forward pass of the ContactPredictionHead.

    Args:
        outputs: The outputs of the model.
        attention_mask: The attention mask for the inputs.
        input_ids: The input ids for the inputs.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if attention_mask is None:
        if isinstance(input_ids, NestedTensor):
            input_ids, attention_mask = input_ids.tensor, input_ids.mask
        else:
            if input_ids is None:
                raise ValueError(
                    f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
                )
            if self.pad_token_id is None:
                raise ValueError(
                    f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
                )
            attention_mask = input_ids.ne(self.pad_token_id)

    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[-1]
    attentions = torch.stack(output, 1)

    # In the original model, attentions for padding tokens are completely zeroed out.
    # This makes no difference most of the time because the other tokens won't attend to them,
    # but it does for the contact prediction task, which takes attentions as input,
    # so we have to mimic that here.
    attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    attentions *= attention_mask[:, None, None, :, :]

    # remove cls token attentions
    if self.bos_token_id is not None:
        attentions = attentions[..., 1:, 1:]
        # process attention_mask and input_ids to make removal of eos token happy
        attention_mask = attention_mask[..., 1:]
        if input_ids is not None:
            input_ids = input_ids[..., 1:]
    # remove eos token attentions
    if self.eos_token_id is not None:
        if input_ids is not None:
            eos_mask = input_ids.ne(self.eos_token_id).to(attentions)
        else:
            last_valid_indices = attention_mask.sum(dim=-1)
            seq_length = attention_mask.size(-1)
            eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions *= eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]

    # features: batch x channels x input_ids x input_ids (symmetric)
    batch_size, layers, heads, seqlen, _ = attentions.size()
    attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
    attentions = attentions.to(self.decoder.weight.device)
    attentions = average_product_correct(symmetrize(attentions))
    attentions = attentions.permute(0, 2, 3, 1).squeeze(3)

    return super().forward(attentions, labels, **kwargs)
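
A minimal sketch of feeding per-layer attentions to the contact head. It assumes RnaFmConfig and fabricates random attentions with the layout that output_attentions=True would produce; input_ids are placeholders used to derive the mask and locate special tokens:

Python
import torch
from multimolecule import RnaFmConfig  # assumption: stands in for any PreTrainedConfig
from multimolecule.module.heads.contact import ContactPredictionHead

config = RnaFmConfig(num_labels=1)
head = ContactPredictionHead(config)

batch_size, seq_len = 2, 8
# One tensor per layer, each of shape (batch, num_heads, seq_len, seq_len).
attentions = tuple(
    torch.rand(batch_size, config.num_attention_heads, seq_len, seq_len)
    for _ in range(config.num_hidden_layers)
)
outputs = {"attentions": attentions}

# Placeholder input ids of the form <bos> ... <eos>.
token_ids = torch.randint(5, config.vocab_size, (batch_size, seq_len - 2))
input_ids = torch.cat(
    [
        torch.full((batch_size, 1), config.bos_token_id),
        token_ids,
        torch.full((batch_size, 1), config.eos_token_id),
    ],
    dim=-1,
)

logits = head(outputs, input_ids=input_ids).logits
print(logits.shape)  # torch.Size([2, 6, 6]) after the special tokens are removed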

multimolecule.module.heads.pretrain

MaskedLMHead

Bases: Module

Head for masked language modeling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `MaskedLMHeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |
Source code in multimolecule/module/heads/pretrain.py
Python
@HeadRegistry.register("masked_lm")
class MaskedLMHead(nn.Module):
    r"""
    Head for masked language modeling.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    def __init__(
        self, config: PreTrainedConfig, weight: Tensor | None = None, head_config: MaskedLMHeadConfig | None = None
    ):
        super().__init__()
        if head_config is None:
            head_config = config.lm_head if hasattr(config, "lm_head") else config.head  # type: ignore[assignment]
        self.config: MaskedLMHeadConfig = head_config  # type: ignore[assignment]
        if self.config.hidden_size is None:
            self.config.hidden_size = config.hidden_size
        self.num_labels = config.vocab_size
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HeadTransformRegistryHF.build(self.config)
        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=False)
        if weight is not None:
            self.decoder.weight = weight
        if self.config.bias:
            self.bias = nn.Parameter(torch.zeros(self.num_labels))
            self.decoder.bias = self.bias
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(
        self, outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None
    ) -> HeadOutput:
        r"""
        Forward pass of the MaskedLMHead.

        Args:
            outputs: The outputs of the model.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
        output = self.dropout(output)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage())))
            return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
        return HeadOutput(output)

output_name class-attribute instance-attribute

Python
output_name: str = 'last_hidden_state'

The default output to use for the head.

forward

Python
forward(outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None) -> HeadOutput

Forward pass of the MaskedLMHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `ModelOutput \| Tuple[Tensor, ...]` | The outputs of the model. | required |
| `labels` | `Tensor \| None` | The labels for the head. | `None` |
| `output_name` | `str \| None` | The name of the output to use. Defaults to `self.output_name`. | `None` |
Source code in multimolecule/module/heads/pretrain.py
Python
def forward(
    self, outputs: ModelOutput | Tuple[Tensor, ...], labels: Tensor | None = None, output_name: str | None = None
) -> HeadOutput:
    r"""
    Forward pass of the MaskedLMHead.

    Args:
        outputs: The outputs of the model.
        labels: The labels for the head.
        output_name: The name of the output to use.
            Defaults to `self.output_name`.
    """
    if isinstance(outputs, (Mapping, ModelOutput)):
        output = outputs[output_name or self.output_name]
    elif isinstance(outputs, tuple):
        output = outputs[0]
    else:
        raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
    output = self.dropout(output)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)
    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage())))
        return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
    return HeadOutput(output)
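
A minimal sketch of the masked language modeling head, assuming RnaFmConfig; the optional `weight` argument (not used here) can be passed to tie the decoder to the input embedding matrix:

Python
import torch
from multimolecule import RnaFmConfig  # assumption: stands in for any PreTrainedConfig
from multimolecule.module.heads.pretrain import MaskedLMHead

config = RnaFmConfig()
head = MaskedLMHead(config)

batch_size, seq_len = 2, 8
outputs = {"last_hidden_state": torch.randn(batch_size, seq_len, config.hidden_size)}

logits = head(outputs).logits
print(logits.shape)  # (batch_size, seq_len, config.vocab_size)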

multimolecule.module.heads.generic

PredictionHead

Bases: Module

Head for tasks at all levels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PreTrainedConfig` | The configuration object for the model. | required |
| `head_config` | `HeadConfig \| None` | The configuration object for the head. If `None`, will use configuration from the `config`. | `None` |
Source code in multimolecule/module/heads/generic.py
Python
class PredictionHead(nn.Module):
    r"""
    Head for all-level of tasks.

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    num_labels: int

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__()
        if head_config is None:
            head_config = config.head
        self.config = head_config
        if self.config.hidden_size is None:
            self.config.hidden_size = config.hidden_size
        if self.config.num_labels is None:
            self.config.num_labels = config.num_labels
        if self.config.problem_type is None:
            self.config.problem_type = config.problem_type
        self.num_labels = self.config.num_labels
        self.dropout = nn.Dropout(self.config.dropout)
        self.transform = HeadTransformRegistryHF.build(self.config)
        self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias)
        self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
        self.criterion = Criterion(self.config)

    def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
        r"""
        Forward pass of the PredictionHead.

        Args:
            embeddings: The embeddings to be passed through the head.
            labels: The labels for the head.
        """
        if kwargs:
            warn(
                f"The following arguments are not applicable to {self.__class__.__name__}"
                f"and will be ignored: {kwargs.keys()}"
            )
        output = self.dropout(embeddings)
        output = self.transform(output)
        output = self.decoder(output)
        if self.activation is not None:
            output = self.activation(output)
        if labels is not None:
            if isinstance(labels, NestedTensor):
                if isinstance(output, Tensor):
                    output = labels.nested_like(output, strict=False)
                return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage())))
            return HeadOutput(output, self.criterion(output, labels))
        return HeadOutput(output)

forward

Python
forward(embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput

Forward pass of the PredictionHead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embeddings` | `Tensor` | The embeddings to be passed through the head. | required |
| `labels` | `Tensor \| None` | The labels for the head. | required |
Source code in multimolecule/module/heads/generic.py
Python
def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
    r"""
    Forward pass of the PredictionHead.

    Args:
        embeddings: The embeddings to be passed through the head.
        labels: The labels for the head.
    """
    if kwargs:
        warn(
            f"The following arguments are not applicable to {self.__class__.__name__}"
            f"and will be ignored: {kwargs.keys()}"
        )
    output = self.dropout(embeddings)
    output = self.transform(output)
    output = self.decoder(output)
    if self.activation is not None:
        output = self.activation(output)
    if labels is not None:
        if isinstance(labels, NestedTensor):
            if isinstance(output, Tensor):
                output = labels.nested_like(output, strict=False)
            return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage())))
        return HeadOutput(output, self.criterion(output, labels))
    return HeadOutput(output)
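
A minimal sketch of the generic head applied to pre-computed embeddings, with labels supplied so that the criterion returns a loss; RnaFmConfig is assumed as the backbone configuration:

Python
import torch
from multimolecule import RnaFmConfig  # assumption: stands in for any PreTrainedConfig
from multimolecule.module.heads.config import HeadConfig
from multimolecule.module.heads.generic import PredictionHead

config = RnaFmConfig()
head_config = HeadConfig(num_labels=2, problem_type="single_label_classification")
head = PredictionHead(config, head_config)

embeddings = torch.randn(4, config.hidden_size)
labels = torch.randint(0, 2, (4,))

output = head(embeddings, labels)
print(output.logits.shape, output.loss)  # torch.Size([4, 2]) plus a scalar loss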

multimolecule.module.heads.output

HeadOutput dataclass

Bases: ModelOutput

Output of a prediction head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logits` | `FloatTensor` | The prediction logits from the head. | required |
| `loss` | `FloatTensor \| None` | The loss from the head. Defaults to `None`. | `None` |
Source code in multimolecule/module/heads/output.py
Python
@dataclass
class HeadOutput(ModelOutput):
    r"""
    Output of a prediction head.

    Args:
        logits: The prediction logits from the head.
        loss: The loss from the head.
            Defaults to None.
    """

    logits: FloatTensor
    loss: FloatTensor | None = None
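
A minimal sketch of inspecting a HeadOutput; it follows the usual 🤗 Transformers ModelOutput semantics:

Python
import torch
from multimolecule.module.heads.output import HeadOutput

output = HeadOutput(logits=torch.randn(4, 2))
print(output.logits.shape)  # torch.Size([4, 2])
print(output.loss)          # None when no labels were provided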