Skip to content

Summary

multimolecule.utils.summary.calculate_flops

Python
calculate_flops(
    model: Module,
    *args,
    module_ops: Mapping[type, Callable] | None = None,
    excluded_modules: Type | Tuple[Type, ...] | None = None,
    format_spec: str | None = None,
    **kwargs
) -> int | str

Calculate the number of FLOPs (floating point operations) in a PyTorch model.

This performs a single forward pass with hooks attached to count operations per module. For HuggingFace transformer models, attention matrix operations (Q@K^T and attn@V) that use raw torch.matmul are automatically estimated from the model config.

Parameters:

Name Type Description Default

model

Module

The model for which to calculate the FLOPs.

required

*args

Positional arguments forwarded to model.forward().

()

module_ops

Mapping[type, Callable]

Custom per-module-type hooks. Each callable has signature (module, input, output) -> (flops, macs). Overrides built-in hooks for the same type.

None

excluded_modules

type | tuple[type, ...]

Module types to exclude from the calculation.

None

format_spec

str

A format specifier to format the output. If None, the number of FLOPs is returned as an int. Otherwise, the number of FLOPs is returned as a str formatted according to the format specifier. Defaults to None.

None

**kwargs

Keyword arguments forwarded to model.forward().

{}

Returns:

Type Description
int | str

int | str: The number of FLOPs in the model.

Examples:

Python Console Session
1
2
3
4
5
6
>>> model = nn.Linear(768, 3072)
>>> input = torch.randn(1, 128, 768)
>>> calculate_flops(model, input)
604372992
>>> calculate_flops(model, input, format_spec=",")
'604,372,992'
Source code in multimolecule/utils/summary/flops.py
Python
def calculate_flops(
    model: nn.Module,
    *args,
    module_ops: Mapping[type, Callable] | None = None,
    excluded_modules: Type | Tuple[Type, ...] | None = None,
    format_spec: str | None = None,
    **kwargs,
) -> int | str:
    """
    Count the FLOPs (floating point operations) performed by one forward pass of a PyTorch model.

    The heavy lifting is delegated to ``_calculate_ops``, which runs ``model`` once with hooks
    attached to count operations per module. For HuggingFace transformer models, attention matrix
    operations (Q@K^T and attn@V) that use raw ``torch.matmul`` are automatically estimated from
    the model config. Only the FLOPs half of the ``(flops, macs)`` pair is reported here.

    Args:
        model (torch.nn.Module): The model whose FLOPs are measured.
        *args: Positional arguments passed through to ``model.forward()``.
        module_ops (Mapping[type, Callable], optional): Extra per-module-type hooks, each with
            signature ``(module, input, output) -> (flops, macs)``. They take precedence over the
            built-in hooks registered for the same type.
        excluded_modules (type | tuple[type, ...], optional): Module types left out of the count.
        format_spec (str, optional): Format specifier applied to the result. When None (the
            default), the count is returned as an int; otherwise it is rendered with
            ``format(count, format_spec)`` and returned as a str.
        **kwargs: Keyword arguments passed through to ``model.forward()``.

    Returns:
        int | str: The number of FLOPs in the model.

    Examples:
        >>> model = nn.Linear(768, 3072)
        >>> input = torch.randn(1, 128, 768)
        >>> calculate_flops(model, input)
        604372992
        >>> calculate_flops(model, input, format_spec=",")
        '604,372,992'
    """

    total_flops, _unused_macs = _calculate_ops(
        model,
        *args,
        module_ops=module_ops,
        excluded_modules=excluded_modules,
        **kwargs,
    )
    # Raw integer unless the caller asked for a formatted string.
    return total_flops if format_spec is None else format(total_flops, format_spec)

multimolecule.utils.summary.calculate_macs

Python
calculate_macs(
    model: Module,
    *args,
    module_ops: Mapping[type, Callable] | None = None,
    excluded_modules: Type | Tuple[Type, ...] | None = None,
    format_spec: str | None = None,
    **kwargs
) -> int | str

Calculate the number of MACs (multiply-accumulate operations) in a PyTorch model.

This performs a single forward pass with hooks attached to count operations per module. For HuggingFace transformer models, attention matrix operations (Q@K^T and attn@V) that use raw torch.matmul are automatically estimated from the model config.

Parameters:

Name Type Description Default

model

Module

The model for which to calculate the MACs.

required

*args

Positional arguments forwarded to model.forward().

()

module_ops

Mapping[type, Callable]

Custom per-module-type hooks. Each callable has signature (module, input, output) -> (flops, macs). Overrides built-in hooks for the same type.

None

excluded_modules

type | tuple[type, ...]

Module types to exclude from the calculation.

None

format_spec

str

A format specifier to format the output. If None, the number of MACs is returned as an int. Otherwise, the number of MACs is returned as a str formatted according to the format specifier. Defaults to None.

None

**kwargs

Keyword arguments forwarded to model.forward().

{}

Returns:

Type Description
int | str

int | str: The number of MACs in the model.

Examples:

Python Console Session
1
2
3
4
5
6
>>> model = nn.Linear(768, 3072)
>>> input = torch.randn(1, 128, 768)
>>> calculate_macs(model, input)
301989888
>>> calculate_macs(model, input, format_spec=",")
'301,989,888'
Source code in multimolecule/utils/summary/flops.py
Python
def calculate_macs(
    model: nn.Module,
    *args,
    module_ops: Mapping[type, Callable] | None = None,
    excluded_modules: Type | Tuple[Type, ...] | None = None,
    format_spec: str | None = None,
    **kwargs,
) -> int | str:
    """
    Count the MACs (multiply-accumulate operations) performed by one forward pass of a PyTorch model.

    The heavy lifting is delegated to ``_calculate_ops``, which runs ``model`` once with hooks
    attached to count operations per module. For HuggingFace transformer models, attention matrix
    operations (Q@K^T and attn@V) that use raw ``torch.matmul`` are automatically estimated from
    the model config. Only the MACs half of the ``(flops, macs)`` pair is reported here.

    Args:
        model (torch.nn.Module): The model whose MACs are measured.
        *args: Positional arguments passed through to ``model.forward()``.
        module_ops (Mapping[type, Callable], optional): Extra per-module-type hooks, each with
            signature ``(module, input, output) -> (flops, macs)``. They take precedence over the
            built-in hooks registered for the same type.
        excluded_modules (type | tuple[type, ...], optional): Module types left out of the count.
        format_spec (str, optional): Format specifier applied to the result. When None (the
            default), the count is returned as an int; otherwise it is rendered with
            ``format(count, format_spec)`` and returned as a str.
        **kwargs: Keyword arguments passed through to ``model.forward()``.

    Returns:
        int | str: The number of MACs in the model.

    Examples:
        >>> model = nn.Linear(768, 3072)
        >>> input = torch.randn(1, 128, 768)
        >>> calculate_macs(model, input)
        301989888
        >>> calculate_macs(model, input, format_spec=",")
        '301,989,888'
    """

    _unused_flops, total_macs = _calculate_ops(
        model,
        *args,
        module_ops=module_ops,
        excluded_modules=excluded_modules,
        **kwargs,
    )
    # Raw integer unless the caller asked for a formatted string.
    return total_macs if format_spec is None else format(total_macs, format_spec)

multimolecule.utils.summary.count_parameters

Python
count_parameters(
    model: Module,
    trainable: bool = True,
    unique: bool = True,
    format_spec: str | None = None,
) -> int | str

Count the number of parameters in a PyTorch model, optionally only counting those that require gradients (i.e., are trainable) and/or are unique.

Parameters:

Name Type Description Default

model

Module

The model for which to count the parameters.

required

trainable

bool

Whether to count only parameters that require gradients. Defaults to True.

True

unique

bool

Whether to count only unique parameters. Defaults to True.

True

format_spec

str

A format specifier to format the output. If None, the number of parameters is returned as an int. Otherwise, the number of parameters is returned as a str formatted according to the format specifier. Defaults to None.

None

Returns:

Type Description
int | str

int | str: The number of parameters in the model, according to the criteria specified by trainable and unique.

Examples:

Python Console Session
>>> from torchvision import models
Python Console Session
1
2
3
4
5
>>> model = models.alexnet()
>>> count_parameters(model)
61100840
>>> count_parameters(model, format_spec=",")
'61,100,840'
Python Console Session
1
2
3
4
5
>>> model = models.vgg16()
>>> count_parameters(model)
138357544
>>> count_parameters(model, format_spec=",")
'138,357,544'
Python Console Session
1
2
3
4
5
>>> model = models.resnet50()
>>> count_parameters(model)
25557032
>>> count_parameters(model, format_spec=",")
'25,557,032'
Source code in multimolecule/utils/summary/parameters.py
Python
def count_parameters(
    model: nn.Module, trainable: bool = True, unique: bool = True, format_spec: str | None = None
) -> int | str:
    """
    Count the number of parameters in a PyTorch model, optionally only counting
    those that require gradients (i.e., are trainable) and/or are unique.

    Args:
        model (torch.nn.Module): The model for which to count the parameters.
        trainable (bool, optional): Whether to count only parameters that require gradients.
            Defaults to True.
        unique (bool, optional): Whether to count only unique parameters. Parameters that start at the
            same memory address on the same device (e.g. distinct ``nn.Parameter`` objects sharing one
            storage) are counted once. Parameters that expose no real data pointer, such as those on the
            ``meta`` device, are distinguished by object identity instead, so each of them is counted.
            Defaults to True.
        format_spec (str, optional): A format specifier to format the output.
            If None, the number of parameters is returned as an int.
            Otherwise, the number of parameters is returned as a str formatted according to the format specifier.
            Defaults to None.

    Returns:
        int | str: The number of parameters in the model, according to the criteria specified by `trainable` and
            `unique`.

    Examples:
        >>> from torchvision import models

        >>> model = models.alexnet()
        >>> count_parameters(model)
        61100840
        >>> count_parameters(model, format_spec=",")
        '61,100,840'

        >>> model = models.vgg16()
        >>> count_parameters(model)
        138357544
        >>> count_parameters(model, format_spec=",")
        '138,357,544'

        >>> model = models.resnet50()
        >>> count_parameters(model)
        25557032
        >>> count_parameters(model, format_spec=",")
        '25,557,032'
    """

    num_parameters = 0
    seen = set()
    for p in model.parameters():
        if trainable and not p.requires_grad:
            # Frozen parameters are neither counted nor recorded as seen, so a trainable parameter
            # sharing their storage is still counted.
            continue
        if unique:
            try:
                ptr = p.data_ptr()
            except RuntimeError:
                # Some PyTorch versions refuse to expose a data pointer for tensors whose storage
                # is not materialized (e.g. ``meta`` tensors).
                ptr = 0
            # A null pointer (meta / empty tensors) would make every such parameter collide on the
            # same key and be dropped after the first; fall back to object identity for those.
            # Tuple keys (device, ptr) can never compare equal to the int identity keys.
            key = (p.device, ptr) if ptr else id(p)
            if key in seen:
                continue
            seen.add(key)
        num_parameters += p.numel()
    if format_spec is not None:
        return format(num_parameters, format_spec)
    return num_parameters