Skip to content

Config

multimolecule.runner.Config

Bases: chanfig.Config

Source code in multimolecule/runner/config.py
Python
class Config(chanfig.Config):
    """Top-level configuration for the ``multimolecule`` runner.

    Extends :class:`chanfig.Config` with default sub-configs for the dataloader,
    network, optimizer, scheduler and EMA. :meth:`post` validates required
    fields and fills in derived settings.
    """

    name: str
    seed: int = 1016
    runner: str = "multimolecule"

    pretrained: Optional[str]
    use_pretrained: bool = True
    transformers: PretrainedConfig
    epoch_end: int = 20

    # A plain string assigned here is converted to ``DataConfig(root=...)`` by ``set``.
    data: Union[DataConfig, str]

    tensorboard: bool = True
    save_interval: int = 10

    platform: str = "auto"
    allow_tf32: bool = True
    reduced_precision_reduction: bool = False

    def __init__(self, *args, **kwargs):
        # Install the default sub-configs first. The args/kwargs are then handed
        # to chanfig.Config.__init__; presumably they merge into / override these
        # defaults (NOTE(review): confirm against chanfig's semantics).
        self.dataloader = DataloaderConfig()
        self.network = NetworkConfig()
        self.optim = OptimConfig()
        self.sched = SchedulerConfig()
        self.ema = EmaConfig()
        super().__init__(*args, **kwargs)

    def post(self):
        """Validate required fields and derive dependent settings.

        Forwards ``pretrained`` and ``use_pretrained`` into
        ``network.backbone.sequence``. If no ``name`` was given, derives one
        via :meth:`get_name`.

        Raises:
            ValueError: If neither ``pretrained`` nor ``checkpoint`` is set, or
                if ``data`` is not set.
        """
        super().post()
        if "pretrained" not in self and "checkpoint" not in self:
            raise ValueError("Either one of `pretrained` or `checkpoint` must be specified")
        if "data" not in self:
            raise ValueError("`data` must be specified")
        if "pretrained" in self:
            # Forward ``pretrained`` to the backbone's sequence config.
            self["network.backbone.sequence.name"] = self.get("pretrained")
        self.name = str(self.name) if "name" in self else self.get_name(self.get("pretrained", "null"))
        self["network.backbone.sequence.use_pretrained"] = self.use_pretrained

    def get_name(self, pretrained: Optional[str]) -> str:
        """Build a default run name from ``pretrained``, the optimizer and the seed.

        The result has the form ``<model>[-<lr>@<optim type>]-<seed>``. The
        ``<model>`` part depends on what ``pretrained`` points to:

        * An existing file: its path relative to its grandparent directory,
          without suffix (``a/b/model.pt`` -> ``b/model``). A file with no
          grandparent (for example a bare ``model.pt``) uses its stem instead.
        * An existing directory: the directory's stem.
        * Anything else: ``pretrained`` as-is.

        Every ``/`` in ``<model>`` is replaced with ``--``. The optimizer part
        is added only when ``optim`` is truthy.

        Args:
            pretrained: Model path or identifier. ``None`` is treated as
                ``"null"``, the same default that :meth:`post` uses.

        Returns:
            The derived run name.
        """
        if pretrained is None:
            # ``pretrained`` is Optional. Without this, os.path.exists(None)
            # would raise TypeError.
            pretrained = "null"
        if os.path.exists(pretrained):
            path = Path(pretrained)
            if os.path.isfile(pretrained):
                # ``parents[1]`` does not exist for a bare filename or for a file
                # directly under the filesystem root. Indexing it there raised
                # IndexError, so fall back to the file's stem.
                if len(path.parents) > 1:
                    pretrained = str(path.relative_to(path.parents[1]).with_suffix(""))
                else:
                    pretrained = path.stem
            else:
                pretrained = path.stem
        name = pretrained.replace("/", "--")
        if self.get("optim"):
            optim_name = self.optim.get("type", "no")
            name += f"-{self.optim.lr}@{optim_name}"
        return name + f"-{self.seed}"

    def set(self, key: str, value: Any):
        """Set ``key`` to ``value``, wrapping a string ``data`` in a ``DataConfig``."""
        if key == "data" and isinstance(value, str):
            # Allow ``data="some/root"`` as shorthand for ``DataConfig(root=...)``.
            value = DataConfig(root=value)
        super().set(key, value)