BPNet
Base-resolution convolutional neural network for predicting transcription-factor binding profiles from DNA sequence.
Disclaimer
This is an UNOFFICIAL implementation of Base-resolution models of transcription-factor binding reveal soft motif syntax by Žiga Avsec, Melanie Weilert, et al.
The OFFICIAL repository of BPNet is at kundajelab/bpnet.
Tip
The MultiMolecule team has confirmed that the provided model and checkpoints are producing the same intermediate representations as the original implementation.
The team releasing BPNet did not write this model card for this model so this model card has been written by the MultiMolecule team.
Model Details
BPNet is a convolutional neural network (CNN) trained to predict base-resolution transcription-factor binding signal (ChIP-nexus) from primary DNA sequence. It uses a convolutional motif stem followed by a stack of dilated residual convolutions that aggregate ~1 kb of genomic context. The output is factorized into profile and count branches, and the usable base-resolution prediction is reconstructed by BpNetForProfilePrediction.postprocess. Please refer to the Training Details section for more information on the training process.
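The context size follows from the convolution stack. The sketch below assumes that dilated layer i (1-based) uses dilation 2**i, which is the schedule described for the original BPNet. The MultiMolecule encoder is not shown on this page, so treat this as an assumption. Under it, each backbone position sees 2,069 bp, or about ±1 kb around itself. The profile head's transposed convolution is not counted.

```python
# Receptive-field sketch for a BPNet-style backbone.
# ASSUMPTION: dilated layer i (1-based) uses dilation 2**i; not read from the MultiMolecule source.


def receptive_field(stem_kernel_size: int = 25,
                    dilated_kernel_size: int = 3,
                    num_dilated_layers: int = 9) -> int:
    """Number of input positions that influence one backbone output position."""
    rf = stem_kernel_size  # the stem convolution sees this many bases
    for i in range(1, num_dilated_layers + 1):
        dilation = 2 ** i
        # each dilated convolution widens the field by (kernel - 1) * dilation
        rf += (dilated_kernel_size - 1) * dilation
    return rf


rf = receptive_field()
print(rf, (rf - 1) // 2)  # total width, and bases seen on each side of the centre
```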
Model Specification
| Num Layers | Hidden Size | Num Parameters (M) | FLOPs (G) | MACs (G) |
|---|---|---|---|---|
| 10 | 64 | 0.13 | 0.24 | 0.12 |
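You can roughly reproduce the parameter count by hand. Num Layers is presumably the stem plus the nine dilated blocks. The sketch below counts weights and biases under a hypothetical layer layout, listed in its comments:

- a Conv1d stem on one-hot input;
- 3-wide dilated Conv1d blocks;
- a ConvTranspose1d profile head;
- a Linear count head on pooled features.

With the default configuration it gives 132,560 parameters (about 0.13 M), which agrees with the table. The real modules may still differ in detail.

```python
# Rough parameter count for a BPNet-style model.
# HYPOTHETICAL layout (not read from the MultiMolecule source):
#   stem:    Conv1d(vocab_size, hidden_size, stem_kernel_size)
#   blocks:  num_dilated_layers x Conv1d(hidden_size, hidden_size, dilated_kernel_size)
#   profile: ConvTranspose1d(hidden_size, num_labels, profile_kernel_size)
#   count:   Linear(hidden_size, num_labels) on globally pooled features


def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    # weight + bias; same count for Conv1d and ConvTranspose1d with groups=1
    return in_ch * out_ch * kernel + out_ch


def bpnet_params(vocab_size=5, hidden_size=64, stem_kernel_size=25, num_dilated_layers=9,
                 dilated_kernel_size=3, profile_kernel_size=25, num_tasks=4, num_strands=2) -> int:
    num_labels = num_tasks * num_strands
    stem = conv_params(vocab_size, hidden_size, stem_kernel_size)
    blocks = num_dilated_layers * conv_params(hidden_size, hidden_size, dilated_kernel_size)
    profile = conv_params(hidden_size, num_labels, profile_kernel_size)
    count = hidden_size * num_labels + num_labels  # Linear weight + bias
    return stem + blocks + profile + count


total = bpnet_params()
print(total, round(total / 1e6, 2))
```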
Links
- Code: multimolecule.bpnet
- Data: BPNet manuscript data
- Paper: Base-resolution models of transcription-factor binding reveal soft motif syntax
- Developed by: Žiga Avsec, Melanie Weilert, Avanti Shrikumar, Sabrina Krueger, Amr Alexandari, Khyati Dalal, Robin Fropf, Charles McAnany, Julien Gagneur, Anshul Kundaje, Julia Zeitlinger
- Model type: 1D dilated CNN with factorized profile-and-count heads for base-resolution transcription-factor binding prediction
- Original Repository: kundajelab/bpnet
Usage
The model file depends on the multimolecule library. You can install it using pip:
```bash
pip install multimolecule
```
Direct Use
Transcription-Factor Binding Profile Prediction
You can use this model directly to predict transcription-factor binding profiles of a DNA sequence:
```python
>>> from multimolecule import DnaTokenizer, BpNetForProfilePrediction
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
>>> model = BpNetForProfilePrediction.from_pretrained("multimolecule/bpnet")
>>> output = model(**tokenizer("ACGTNACGTN", return_tensors="pt"))
>>> output.keys()
odict_keys(['profile_logits', 'count_logits'])
>>> output["profile_logits"].shape
torch.Size([1, 10, 8])
>>> output["count_logits"].shape
torch.Size([1, 8])
>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 10, 8])
```
The recombined track is the usable base-resolution prediction. Its last dimension holds num_tasks × num_strands = 8 channels in task-major order. There are four factors (Oct4, Sox2, Nanog, Klf4), and each has a forward (plus) and a reverse (minus) strand.
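To pick one factor and strand out of the track, convert between the flat channel index and a (task, strand) pair. This sketch assumes the released checkpoint orders its tasks as listed above (Oct4, Sox2, Nanog, Klf4). It also assumes the task-major layout used by the output_channels property of BpNetForProfilePrediction. Check model.output_channels to confirm both before relying on the indices.

```python
from __future__ import annotations

# ASSUMED task order for the released checkpoint; verify against model.output_channels.
TASKS = ["Oct4", "Sox2", "Nanog", "Klf4"]
STRANDS = ["plus", "minus"]  # forward, reverse


def channel_index(task: str, strand: str) -> int:
    # task-major layout: all strands of task 0, then all strands of task 1, ...
    return TASKS.index(task) * len(STRANDS) + STRANDS.index(strand)


def channel_label(index: int) -> tuple[str, str]:
    task_idx, strand_idx = divmod(index, len(STRANDS))
    return TASKS[task_idx], STRANDS[strand_idx]


print(channel_index("Nanog", "minus"))
print(channel_label(6))
```

With a real output, `track[..., channel_index("Nanog", "minus")]` would then select the reverse-strand Nanog profile, provided the assumed ordering holds.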
Interface
- Input length: 1000 bp DNA window (the training window size). The convolutional network also accepts other lengths, as the 10-nt example above shows.
- Output: factorized (profile_logits, count_logits). Recombine them into the usable base-resolution track with BpNetForProfilePrediction.postprocess.
- Output shape: (batch_size, profile_length, num_tasks × num_strands); Oct4 / Sox2 / Nanog / Klf4 × forward / reverse = 8 channels
Training Details
BPNet was trained to predict the base-resolution ChIP-nexus binding profiles of the pluripotency transcription factors Oct4, Sox2, Nanog and Klf4 in mouse embryonic stem cells.
Training Data
The published BPNet-OSKN model was trained on ChIP-nexus profiles for Oct4, Sox2, Nanog and Klf4, using 1 kb genomic windows centered on detected binding peaks. The training regions and trained model files are distributed as part of the BPNet manuscript data.
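Preparing your own inputs in the same shape means cutting a fixed-width window centred on a peak summit. Below is a minimal sketch with a hypothetical helper. It assumes the summit is a 0-based position inside a Python string and pads with "N" wherever the window runs past either end. It is not the preprocessing code used for the manuscript data.

```python
def center_window(sequence: str, summit: int, width: int = 1000) -> str:
    """Hypothetical helper: return `width` bases centred on `summit`, padded with N at the edges."""
    if not 0 <= summit < len(sequence):
        raise ValueError("summit must be a valid 0-based position in sequence")
    start = summit - width // 2
    end = start + width
    left_pad = max(0, -start)                 # bases missing before the sequence start
    right_pad = max(0, end - len(sequence))   # bases missing after the sequence end
    core = sequence[max(0, start):min(len(sequence), end)]
    return "N" * left_pad + core + "N" * right_pad


window = center_window("ACGT" * 500, summit=1000)
print(len(window))
```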
Training Procedure
Pre-training
The model was trained with a composite loss: a multinomial negative log-likelihood on the per-position profile shape plus a mean-squared-error regression on the log total counts.
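For one task/strand track, the composite loss can be sketched as follows. The profile term is the full multinomial negative log-likelihood of the observed per-position counts under softmax(logits). Its combinatorial term does not depend on the logits, so it only shifts the loss. The count term is a squared error between the predicted and target log total counts, scaled by count_loss_weight as in BpNetConfig. The target is taken as given, matching how the labels are documented ("log total counts"). How MultiMolecule reduces over positions, channels and the batch is not shown here, so this sketch leaves the loss unreduced.

```python
from __future__ import annotations

import math


def multinomial_nll(counts: list[int], logits: list[float]) -> float:
    """NLL of per-position counts under a multinomial with probabilities softmax(logits)."""
    m = max(logits)  # numerically stable log-softmax over positions
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    n = sum(counts)
    # log of the multinomial coefficient n! / prod(k_i!); constant w.r.t. the logits
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in counts)
    return -(log_coef + sum(k * lp for k, lp in zip(counts, log_probs)))


def composite_loss(counts, profile_logits, log_count_pred, log_count_target, count_loss_weight=1.0):
    """Profile multinomial NLL plus weighted squared error on the log total count."""
    count_mse = (log_count_pred - log_count_target) ** 2
    return multinomial_nll(counts, profile_logits) + count_loss_weight * count_mse


counts = [3, 0, 1, 0]
loss = composite_loss(counts, [1.0, 0.0, 0.5, 0.0], math.log(4.0), math.log(sum(counts)))
print(round(loss, 4))
```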
Citation
```bibtex
@article{avsec2021baseresolution,
  author    = {Avsec, {\v{Z}}iga and Weilert, Melanie and Shrikumar, Avanti and Krueger, Sabrina and Alexandari, Amr and Dalal, Khyati and Fropf, Robin and McAnany, Charles and Gagneur, Julien and Kundaje, Anshul and Zeitlinger, Julia},
  title     = {Base-resolution models of transcription-factor binding reveal soft motif syntax},
  journal   = {Nature Genetics},
  volume    = 53,
  number    = 3,
  pages     = {354--366},
  year      = 2021,
  publisher = {Nature Publishing Group},
  doi       = {10.1038/s41588-021-00782-6}
}
```
Note
The artifacts distributed in this repository are part of the MultiMolecule project.
If MultiMolecule supports your research, please cite the MultiMolecule project as follows:
```bibtex
@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}
```
Please use GitHub issues of MultiMolecule for any questions or comments on the model card.
Please contact the authors of the BPNet paper for questions or comments on the paper/model.
License
This model implementation is licensed under the GNU Affero General Public License.
For additional terms and clarifications, please refer to our License FAQ.
```text
SPDX-License-Identifier: AGPL-3.0-or-later
```
API Reference
BpNetConfig
Bases: PreTrainedConfig
This is the configuration class to store the configuration of a
BpNetModel. It is used to instantiate a BPNet model according to the specified
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
configuration to that of the BPNet-OSKN architecture.
Configuration objects inherit from PreTrainedConfig and can be used to
control the model outputs. Read the documentation from PreTrainedConfig
for more information.
BPNet predicts a single base-resolution signal task. Its output is factorized into two terminal branches that share the dilated-convolution backbone:
- a profile branch producing per-position multinomial logits of shape (batch_size, sequence_length, num_tasks * num_strands);
- a count branch producing a scalar per task and strand, of shape (batch_size, num_tasks * num_strands).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vocab_size | int | Vocabulary size of the BPNet model. Defines the number of one-hot input channels derived from input_ids. Defaults to 5 to match the MultiMolecule streamline DNA alphabet (ACGTN). | 5 |
| hidden_size | int | Number of channels in the convolutional backbone. | 64 |
| stem_kernel_size | int | Kernel size of the first (motif) convolution. | 25 |
| num_dilated_layers | int | Number of dilated residual convolution blocks following the stem. | 9 |
| dilated_kernel_size | int | Kernel size of each dilated residual convolution. | 3 |
| profile_kernel_size | int | Kernel size of the transposed convolution in the profile branch. | 25 |
| num_tasks | int | Number of prediction tasks (e.g. transcription factors). | 4 |
| num_strands | int | Number of strands predicted per task. | 2 |
| hidden_act | str | The non-linear activation function (function or string) in the backbone. | 'relu' |
| count_loss_weight | float | The weight applied to the count regression loss when combining it with the profile multinomial loss. | 1.0 |
| head | HeadConfig \| None | The configuration of the generic token prediction head. If not provided, it defaults to regression. | None |
| output_hidden_states | bool | Whether to output the backbone hidden states. | False |
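vocab_size sets the number of one-hot input channels derived from input_ids. Below is a minimal pure-Python sketch of such an encoding. The id order A=0, C=1, G=2, T=3, N=4 is an assumption. The default pad_token_id=4 is consistent with N sitting at index 4, but check the tokenizer's vocabulary. Mapping unknown characters to N is also a choice made for this sketch only.

```python
from __future__ import annotations

# ASSUMED channel order; verify against the DnaTokenizer vocabulary.
ALPHABET = "ACGTN"


def one_hot(sequence: str) -> list[list[int]]:
    """Encode each base as a row of len(ALPHABET) 0/1 values."""
    rows = []
    for base in sequence.upper():
        index = ALPHABET.find(base)
        if index < 0:
            index = ALPHABET.index("N")  # sketch choice: treat unknown characters as N
        rows.append([1 if i == index else 0 for i in range(len(ALPHABET))])
    return rows


encoded = one_hot("ACGTN")
print(len(encoded), len(encoded[0]))
```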
Examples:
```python
>>> from multimolecule import BpNetConfig, BpNetModel
>>> # Initializing a BPNet multimolecule/bpnet style configuration
>>> configuration = BpNetConfig()
>>> # Initializing a model (with random weights) from the multimolecule/bpnet style configuration
>>> model = BpNetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
Source code in multimolecule/models/bpnet/configuration_bpnet.py
```python
class BpNetConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BpNetModel`][multimolecule.models.BpNetModel]. It is used to instantiate a BPNet model according to the specified
    arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
    configuration to that of the BPNet [BPNet-OSKN](https://zenodo.org/records/4294904) architecture.

    Configuration objects inherit from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig] and can be used to
    control the model outputs. Read the documentation from [`PreTrainedConfig`][multimolecule.models.PreTrainedConfig]
    for more information.

    BPNet predicts a single base-resolution signal task whose output is factorized into two terminal branches that
    share the dilated-convolution backbone:

    - a *profile* branch producing per-position multinomial logits of shape
      `(batch_size, sequence_length, num_tasks * num_strands)`;
    - a *count* branch producing a scalar per task and strand of shape `(batch_size, num_tasks * num_strands)`.

    Args:
        vocab_size:
            Vocabulary size of the BPNet model. Defines the number of one-hot input channels derived from `input_ids`.
            Defaults to 5 to match the MultiMolecule `streamline` DNA alphabet (`ACGTN`).
        hidden_size:
            Number of channels in the convolutional backbone.
        stem_kernel_size:
            Kernel size of the first (motif) convolution.
        num_dilated_layers:
            Number of dilated residual convolution blocks following the stem.
        dilated_kernel_size:
            Kernel size of each dilated residual convolution.
        profile_kernel_size:
            Kernel size of the transposed convolution in the profile branch.
        num_tasks:
            Number of prediction tasks (e.g. transcription factors).
        num_strands:
            Number of strands predicted per task.
        hidden_act:
            The non-linear activation function (function or string) in the backbone.
        count_loss_weight:
            The weight applied to the count regression loss when combining it with the profile multinomial loss.
        head:
            The configuration of the generic token prediction head. If not provided, it defaults to regression.
        output_hidden_states:
            Whether to output the backbone hidden states.

    Examples:
        >>> from multimolecule import BpNetConfig, BpNetModel
        >>> # Initializing a BPNet multimolecule/bpnet style configuration
        >>> configuration = BpNetConfig()
        >>> # Initializing a model (with random weights) from the multimolecule/bpnet style configuration
        >>> model = BpNetModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "bpnet"

    def __init__(
        self,
        vocab_size: int = 5,
        hidden_size: int = 64,
        stem_kernel_size: int = 25,
        num_dilated_layers: int = 9,
        dilated_kernel_size: int = 3,
        profile_kernel_size: int = 25,
        num_tasks: int = 4,
        num_strands: int = 2,
        hidden_act: str = "relu",
        count_loss_weight: float = 1.0,
        head: HeadConfig | None = None,
        output_hidden_states: bool = False,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.bos_token_id = bos_token_id  # type: ignore[assignment]
        self.eos_token_id = eos_token_id  # type: ignore[assignment]
        if num_dilated_layers < 1:
            raise ValueError(f"num_dilated_layers ({num_dilated_layers}) must be at least 1.")
        if num_tasks < 1:
            raise ValueError(f"num_tasks ({num_tasks}) must be at least 1.")
        if num_strands < 1:
            raise ValueError(f"num_strands ({num_strands}) must be at least 1.")
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.stem_kernel_size = stem_kernel_size
        self.num_dilated_layers = num_dilated_layers
        self.dilated_kernel_size = dilated_kernel_size
        self.profile_kernel_size = profile_kernel_size
        self.num_tasks = num_tasks
        self.num_strands = num_strands
        self.hidden_act = hidden_act
        self.count_loss_weight = count_loss_weight
        if head is None:
            head = HeadConfig(problem_type="regression")
        else:
            head = HeadConfig(head)
            if head.problem_type is None:
                head.problem_type = "regression"
        self.head = head
        self.output_hidden_states = output_hidden_states

    @property
    def num_labels(self) -> int:
        return self.num_tasks * self.num_strands

    @num_labels.setter
    def num_labels(self, value: int) -> None:
        # ``PretrainedConfig.__init__`` assigns ``num_labels``; BPNet derives it from
        # ``num_tasks * num_strands`` so the assignment is intentionally ignored.
        pass
```
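The num_labels property above is derived from num_tasks * num_strands, and its setter deliberately does nothing. The toy class below (hypothetical, not the real BpNetConfig) isolates that pattern. Its __init__ imitates a base class assigning num_labels, and the derived value still wins.

```python
class ToyConfig:
    """Hypothetical stand-in showing a derived, assignment-proof num_labels."""

    def __init__(self, num_tasks: int = 4, num_strands: int = 2, num_labels: int = 2):
        self.num_tasks = num_tasks
        self.num_strands = num_strands
        self.num_labels = num_labels  # imitates a base-class assignment; silently ignored

    @property
    def num_labels(self) -> int:
        return self.num_tasks * self.num_strands

    @num_labels.setter
    def num_labels(self, value: int) -> None:
        pass  # ignore external assignments; the value is always derived


config = ToyConfig(num_tasks=3, num_strands=1, num_labels=10)
print(config.num_labels)
```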
BpNetForProfilePrediction
Bases: BpNetPreTrainedModel
BPNet with the factorized profile/count head for base-resolution signal prediction.
This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a
per-position value for every input nucleotide. The single base-resolution task is factorized into two terminal
branches sharing the backbone:
- profile_logits: per-position multinomial logits of shape (batch_size, sequence_length, num_labels);
- count_logits: a scalar per task and strand, of shape (batch_size, num_labels),

where num_labels = num_tasks * num_strands. Use postprocess to recombine them into the usable base-resolution track.
Examples:
```python
>>> import torch
>>> from multimolecule import BpNetConfig, BpNetForProfilePrediction, DnaTokenizer
>>> config = BpNetConfig()
>>> model = BpNetForProfilePrediction(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
>>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
>>> output = model(**input)
>>> output["profile_logits"].shape
torch.Size([1, 10, 8])
>>> output["count_logits"].shape
torch.Size([1, 8])
>>> track = model.postprocess(output)
>>> track.shape
torch.Size([1, 10, 8])
```
Source code in multimolecule/models/bpnet/modeling_bpnet.py
```python
class BpNetForProfilePrediction(BpNetPreTrainedModel):
    """
    BPNet with the factorized profile/count head for base-resolution signal prediction.

    This is a token/positional-prediction model: it is registered with the token AutoModel family and predicts a
    per-position value for every input nucleotide. The single base-resolution task is factorized into two terminal
    branches sharing the backbone:

    - `profile_logits`: per-position multinomial logits of shape `(batch_size, sequence_length, num_labels)`;
    - `count_logits`: a scalar per task and strand of shape `(batch_size, num_labels)`,

    where `num_labels = num_tasks * num_strands`. Use [`postprocess`][multimolecule.models.BpNetForProfilePrediction.
    postprocess] to recombine them into the usable base-resolution track.

    Examples:
        >>> import torch
        >>> from multimolecule import BpNetConfig, BpNetForProfilePrediction, DnaTokenizer
        >>> config = BpNetConfig()
        >>> model = BpNetForProfilePrediction(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
        >>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
        >>> output = model(**input)
        >>> output["profile_logits"].shape
        torch.Size([1, 10, 8])
        >>> output["count_logits"].shape
        torch.Size([1, 8])
        >>> track = model.postprocess(output)
        >>> track.shape
        torch.Size([1, 10, 8])
    """

    def __init__(self, config: BpNetConfig):
        super().__init__(config)
        self.model = BpNetModel(config)
        self.profile_count_head = BpNetProfileCountHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def output_channels(self) -> list[str]:
        id2label = getattr(self.config, "id2label", None)
        if id2label is not None and any(
            str(id2label.get(i, f"task_{i}")) != f"LABEL_{i}" for i in range(self.config.num_tasks)
        ):
            tasks = [str(id2label.get(i, f"task_{i}")) for i in range(self.config.num_tasks)]
        else:
            tasks = [f"task_{index}" for index in range(self.config.num_tasks)]
        if self.config.num_strands == 2:
            strands = ["plus", "minus"]
        else:
            strands = [f"strand_{index}" for index in range(self.config.num_strands)]
        return [f"{task}_{strand}" for task in tasks for strand in strands]

    @merge_with_config_defaults
    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: dict[str, Tensor] | tuple[Tensor, Tensor] | None = None,
        # labels: dict {"profile": (batch, seq_len, num_labels) int counts,
        #               "count": (batch, num_labels) log total counts}
        # or a (profile, count) tuple in the same order.
        **kwargs: Unpack[TransformersKwargs],
    ) -> BpNetProfilePredictorOutput:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        head_output = self.profile_count_head(outputs.last_hidden_state, labels)
        return BpNetProfilePredictorOutput(
            loss=head_output.loss,
            profile_logits=head_output.profile_logits,
            count_logits=head_output.count_logits,
            hidden_states=outputs.hidden_states,
        )

    def postprocess(self, outputs: BpNetProfilePredictorOutput | ModelOutput) -> Tensor:
        r"""
        Recombine the factorized profile and count branches into the usable base-resolution track.

        BPNet does not predict the signal track directly; the profile branch predicts the *shape* (a per-position
        multinomial distribution) and the count branch predicts the *total magnitude* (in log space). The usable
        prediction recombines them as `softmax(profile_logits, positions) * exp(count_logits)`.

        Args:
            outputs: The output of [`BpNetForProfilePrediction`][multimolecule.models.BpNetForProfilePrediction].

        Returns:
            The predicted base-resolution track of shape `(batch_size, sequence_length, num_labels)`.
        """
        profile_logits = outputs["profile_logits"]
        count_logits = outputs["count_logits"]
        profile = F.softmax(profile_logits, dim=1)
        return profile * torch.exp(count_logits).unsqueeze(1)
```
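The output_channels property above names each channel `{task}_{strand}`. It uses id2label only when that mapping holds real names rather than LABEL_i placeholders. The free function below mirrors that logic over plain values, so it runs without multimolecule. It is an illustration only, not part of the library's API.

```python
from __future__ import annotations

from typing import Optional


def output_channels(num_tasks: int, num_strands: int, id2label: Optional[dict] = None) -> list[str]:
    """Illustrative mirror of BpNetForProfilePrediction.output_channels."""
    tasks = None
    if id2label is not None:
        candidates = [str(id2label.get(i, f"task_{i}")) for i in range(num_tasks)]
        # placeholder names of the form LABEL_i count as "no real names"
        if any(name != f"LABEL_{i}" for i, name in enumerate(candidates)):
            tasks = candidates
    if tasks is None:
        tasks = [f"task_{i}" for i in range(num_tasks)]
    strands = ["plus", "minus"] if num_strands == 2 else [f"strand_{i}" for i in range(num_strands)]
    return [f"{task}_{strand}" for task in tasks for strand in strands]


print(output_channels(2, 2, {0: "Oct4", 1: "Sox2"}))
```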
postprocess
Recombine the factorized profile and count branches into the usable base-resolution track.
BPNet does not predict the signal track directly; the profile branch predicts the shape (a per-position
multinomial distribution) and the count branch predicts the total magnitude (in log space). The usable
prediction recombines them as softmax(profile_logits, positions) * exp(count_logits).
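The method above applies this formula to batched tensors. It takes F.softmax(profile_logits, dim=1) over positions and multiplies by torch.exp(count_logits).unsqueeze(1). The pure-Python sketch below applies the same formula to a single track. The recombined values sum to exp(count_logit): the count branch sets the total and the profile branch distributes it across positions.

```python
from __future__ import annotations

import math


def recombine(profile_logits: list[float], count_logit: float) -> list[float]:
    """softmax over positions, scaled by exp(count_logit), for one task/strand track."""
    m = max(profile_logits)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in profile_logits]
    total = sum(weights)
    scale = math.exp(count_logit)
    return [scale * w / total for w in weights]


track = recombine([0.0, 1.0, 0.0, -1.0], count_logit=math.log(50.0))
print([round(v, 2) for v in track], round(sum(track), 6))
```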
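Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| outputs | BpNetProfilePredictorOutput \| ModelOutput | The output of BpNetForProfilePrediction. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The predicted base-resolution track of shape (batch_size, sequence_length, num_labels). |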
BpNetForTokenPrediction
Bases: BpNetPreTrainedModel
BPNet backbone with a randomly initialized generic token-prediction head.
This class is intended for downstream fine-tuning from the BPNet backbone. It returns the standard TokenPredictorOutput with a single logits field. By contrast, BpNetForProfilePrediction exposes the published factorized profile_logits / count_logits task head.
Examples:
| Python Console Session |
|---|
| >>> import torch
>>> from multimolecule import BpNetConfig, BpNetForTokenPrediction
>>> config = BpNetConfig()
>>> model = BpNetForTokenPrediction(config)
>>> input_ids = torch.randint(config.vocab_size, (1, 16))
>>> output = model(input_ids)
>>> output["logits"].shape
torch.Size([1, 16, 8])
|
Source code in multimolecule/models/bpnet/modeling_bpnet.py
```python
class BpNetForTokenPrediction(BpNetPreTrainedModel):
    """
    BPNet backbone with a randomly initialized generic token-prediction head.

    This class is intended for downstream fine-tuning from the BPNet backbone. It returns the standard
    [`TokenPredictorOutput`][multimolecule.models.TokenPredictorOutput] with a single `logits` field, unlike
    [`BpNetForProfilePrediction`][multimolecule.models.BpNetForProfilePrediction], which exposes the published
    factorized `profile_logits` / `count_logits` task head.

    Examples:
        >>> import torch
        >>> from multimolecule import BpNetConfig, BpNetForTokenPrediction
        >>> config = BpNetConfig()
        >>> model = BpNetForTokenPrediction(config)
        >>> input_ids = torch.randint(config.vocab_size, (1, 16))
        >>> output = model(input_ids)
        >>> output["logits"].shape
        torch.Size([1, 16, 8])
    """

    def __init__(self, config: BpNetConfig):
        super().__init__(config)
        self.model = BpNetModel(config)
        self.token_head = TokenPredictionHead(config)
        self.head_config = self.token_head.config
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        labels: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, ...] | TokenPredictorOutput:
        head_attention_mask = attention_mask
        if input_ids is None and inputs_embeds is not None and head_attention_mask is None:
            if isinstance(inputs_embeds, NestedTensor):
                head_attention_mask = inputs_embeds.mask
            else:
                head_attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.int, device=inputs_embeds.device)
        outputs = self.model(
            input_ids,
            attention_mask=head_attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )
        output = self.token_head(outputs, head_attention_mask, input_ids, labels)
        return TokenPredictorOutput(
            loss=output.loss,
            logits=output.logits,
            hidden_states=outputs.hidden_states,
        )
```
BpNetModel
Bases: BpNetPreTrainedModel
The bare BPNet dilated-convolution backbone producing per-position features.
Examples:
```python
>>> from multimolecule import BpNetConfig, BpNetModel, DnaTokenizer
>>> config = BpNetConfig()
>>> model = BpNetModel(config)
>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
>>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
>>> output = model(**input)
>>> output["last_hidden_state"].shape
torch.Size([1, 10, 64])
```
Source code in multimolecule/models/bpnet/modeling_bpnet.py
```python
class BpNetModel(BpNetPreTrainedModel):
    """
    The bare BPNet dilated-convolution backbone producing per-position features.

    Examples:
        >>> from multimolecule import BpNetConfig, BpNetModel, DnaTokenizer
        >>> config = BpNetConfig()
        >>> model = BpNetModel(config)
        >>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/bpnet")
        >>> input = tokenizer("ACGTNACGTN", return_tensors="pt")
        >>> output = model(**input)
        >>> output["last_hidden_state"].shape
        torch.Size([1, 10, 64])
    """

    def __init__(self, config: BpNetConfig):
        super().__init__(config)
        self.embeddings = BpNetEmbedding(config)
        self.encoder = BpNetEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: Tensor | NestedTensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | NestedTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BpNetModelOutput:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if isinstance(input_ids, NestedTensor):
            if attention_mask is None:
                attention_mask = input_ids.mask
            input_ids = input_ids.tensor
        if isinstance(inputs_embeds, NestedTensor):
            if attention_mask is None:
                attention_mask = inputs_embeds.mask
            inputs_embeds = inputs_embeds.tensor
        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_output = self.encoder(embedding_output, **kwargs)
        last_hidden_state = encoder_output.last_hidden_state.transpose(1, 2)
        return BpNetModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_output.hidden_states,
        )
```
BpNetModelOutput
dataclass
Bases: ModelOutput
Base class for outputs of the BPNet backbone.
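Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| last_hidden_state | `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)` | Per-position backbone features. | None |
| hidden_states | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of `torch.FloatTensor` (one for the stem output plus one per dilated layer) of shape `(batch_size, sequence_length, hidden_size)`. | None |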
Source code in multimolecule/models/bpnet/modeling_bpnet.py
```python
@dataclass
class BpNetModelOutput(ModelOutput):
    """
    Base class for outputs of the BPNet backbone.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Per-position backbone features.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the stem output plus one per dilated layer) of shape `(batch_size,
            sequence_length, hidden_size)`.
    """

    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
```
BpNetPreTrainedModel
Bases: PreTrainedModel
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Source code in multimolecule/models/bpnet/modeling_bpnet.py
```python
class BpNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BpNetConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _can_record_outputs: dict[str, Any] | None = None
    _no_split_modules = ["BpNetLayer"]

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        # Use transformers.initialization wrappers (imported as `init`); they check the
        # `_is_hf_initialized` flag so they don't clobber tensors loaded from a checkpoint.
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
```
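To see what `_init_weights` amounts to numerically, the sketch below computes two quantities. The first is the standard deviation that PyTorch's kaiming_normal_(mode="fan_out", nonlinearity="relu") uses for a Conv1d weight of shape (out_channels, in_channels, kernel_size): gain / sqrt(fan_out), with gain = sqrt(2) for ReLU and fan_out = out_channels * kernel_size. The second is the uniform bias bound applied to nn.Linear layers. The layer sizes in the example calls are illustrative choices, not read from the model.

```python
import math


def conv_kaiming_std(out_channels: int, kernel_size: int) -> float:
    """Std of kaiming_normal_(mode="fan_out", nonlinearity="relu") for a Conv1d weight."""
    fan_out = out_channels * kernel_size  # size(0) times the receptive-field size
    gain = math.sqrt(2.0)                 # recommended gain for ReLU
    return gain / math.sqrt(fan_out)


def linear_bias_bound(in_features: int) -> float:
    """Bias is drawn from U(-bound, bound) with bound = 1 / sqrt(fan_in)."""
    return 1 / math.sqrt(in_features) if in_features > 0 else 0.0


print(round(conv_kaiming_std(64, 3), 4))  # e.g. a 64-channel conv with kernel size 3
print(linear_bias_bound(64))              # e.g. a Linear layer with 64 inputs
```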
BpNetProfilePredictorOutput
dataclass
Bases: ModelOutput
Base class for outputs of BpNetForProfilePrediction.
The standard single-logits predictor dataclasses cannot express BPNet's factorized output. This model-local dataclass therefore exposes the two terminal branches separately. Use BpNetForProfilePrediction.postprocess to recombine them.
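Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| loss | `torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided | Composite multinomial-NLL (profile) + weighted count-MSE (count) loss. | None |
| profile_logits | `torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)` | Per-position multinomial logits, where num_labels = num_tasks * num_strands. | None |
| count_logits | `torch.FloatTensor` of shape `(batch_size, num_labels)` | Per task/strand log-count scalars. | None |
| hidden_states | `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True` | Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`. | None |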
Source code in multimolecule/models/bpnet/modeling_bpnet.py
```python
@dataclass
class BpNetProfilePredictorOutput(ModelOutput):
    """
    Base class for outputs of [`BpNetForProfilePrediction`][multimolecule.models.BpNetForProfilePrediction].

    The standard single-`logits` predictor dataclasses cannot express BPNet's factorized output, so this model-local
    dataclass exposes the two terminal branches separately. Use
    [`postprocess`][multimolecule.models.BpNetForProfilePrediction.postprocess] to recombine them.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Composite multinomial-NLL (profile) + weighted count-MSE (count) loss.
        profile_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
            Per-position multinomial logits, where `num_labels = num_tasks * num_strands`.
        count_logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Per task/strand log-count scalars.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of backbone hidden states of shape `(batch_size, sequence_length, hidden_size)`.
    """

    loss: torch.FloatTensor | None = None
    profile_logits: torch.FloatTensor | None = None
    count_logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
```