class Config(chanfig.Config):
    """Top-level run configuration.

    Declares the run-level fields below and pre-populates the nested sub-configs
    ``dataloader``, ``network``, ``optim``, ``sched`` and ``ema`` in ``__init__``.
    ``post`` validates that the required fields are present, then propagates the
    pretrained-model settings into ``network.backbone.sequence``.
    """

    name: str  # run name; derived in `post` via `get_name` when not given
    seed: int = 1016
    runner: str = "multimolecule"
    pretrained: Optional[str]  # copied to network.backbone.sequence.name in `post`
    use_pretrained: bool = True  # copied to network.backbone.sequence.use_pretrained in `post`
    transformers: PretrainedConfig
    epoch_end: int = 20
    data: Union[DataConfig, str]  # a plain string is wrapped as DataConfig(root=...) by `set`
    tensorboard: bool = True
    save_interval: int = 10
    platform: str = "auto"
    allow_tf32: bool = True
    reduced_precision_reduction: bool = False

    def __init__(self, *args, **kwargs):
        # The nested sub-configs are created before ``super().__init__`` so that
        # values passed through args/kwargs presumably land on top of these
        # defaults. NOTE(review): confirm against chanfig's merge semantics.
        self.dataloader = DataloaderConfig()
        self.network = NetworkConfig()
        self.optim = OptimConfig()
        self.sched = SchedulerConfig()
        self.ema = EmaConfig()
        super().__init__(*args, **kwargs)

    def post(self):
        """Validate required fields and derive the dependent settings.

        Raises:
            ValueError: If neither ``pretrained`` nor ``checkpoint`` is set,
                or if ``data`` is missing.
        """
        super().post()
        if "pretrained" not in self and "checkpoint" not in self:
            raise ValueError("Either one of `pretrained` or `checkpoint` must be specified")
        if "data" not in self:
            raise ValueError("`data` must be specified")
        if "pretrained" in self:
            self["network.backbone.sequence.name"] = self.get("pretrained")
        # An explicit name is kept (coerced to str); otherwise one is derived
        # from the pretrained identifier ("null" when only a checkpoint is given).
        self.name = str(self.name) if "name" in self else self.get_name(self.get("pretrained", "null"))
        self["network.backbone.sequence.use_pretrained"] = self.use_pretrained

    def get_name(self, pretrained: str) -> str:
        """Build a run name from a pretrained identifier, the optimizer and the seed.

        The result has the form ``<model>[-<lr>@<optim type>]-<seed>``. Every
        ``/`` in the model part is replaced by ``--``.

        - If ``pretrained`` is an existing file, the model part is its path
          relative to its grandparent directory, without the file suffix.
          For example, ``ckpt/rnafm/model.bin`` gives ``rnafm/model``.
          When there is no grandparent, the immediate parent is used instead.
        - If ``pretrained`` is an existing directory, the model part is the
          directory's full name.
        - Otherwise, ``pretrained`` is used verbatim (e.g. a hub model id).

        Args:
            pretrained: Model identifier or local path.

        Returns:
            The derived run name.
        """
        if os.path.exists(pretrained):
            path = Path(pretrained)
            if os.path.isfile(pretrained):
                # ``parents[1]`` does not exist for files such as ``model.bin``
                # (in the cwd) or ``/model.bin``. Fall back to the direct parent
                # instead of raising IndexError.
                parents = path.parents
                anchor = parents[1] if len(parents) > 1 else path.parent
                pretrained = str(path.relative_to(anchor).with_suffix(""))
            else:
                # A directory has no suffix. ``name`` keeps dotted names such as
                # ``rnafm-v1.5`` intact, where ``stem`` would truncate them.
                pretrained = path.name
        name = pretrained.replace("/", "--")
        if self.get("optim"):
            optim_name = self.optim.get("type", "no")
            name += f"-{self.optim.lr}@{optim_name}"
        return name + f"-{self.seed}"

    def set(self, key: str, value: Any):
        """Set ``key`` to ``value``, wrapping a string ``data`` value as ``DataConfig(root=value)``."""
        if key == "data" and isinstance(value, str):
            value = DataConfig(root=value)
        super().set(key, value)