Runner

multimolecule.runner.Runner

Bases: dl.Runner

Source code in multimolecule/runner/runner.py
Python
@RUNNERS.register("multimolecule", default=True)
class Runner(dl.Runner):

    config: Config
    model: Model

    def __init__(self, config: Config):
        if not config.pretrained:
            raise ValueError("A pretrained model must be specified via config.pretrained")
        # We do not want to check if it actually exists, because it might be a local directory
        validate_repo_id(config.pretrained)
        super().__init__(config)
        # must build tokenizer before datasets
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
        self.datasets = self.build_datasets()
        self.dataloaders = self.build_dataloaders()
        self.model = MODELS.build(**self.network)
        self.model = self.model.to(self.device)
        if self.config.training:
            parameters = self.model.trainable_parameters(**self.config.optim)
            self.optimizer = OPTIMIZERS.build(parameters, **self.config.optim)
            if "sched" in self.config:
                step_with_optimizer = self.config.platform != "deepspeed"
                self.scheduler = SCHEDULERS.build(
                    self.optimizer,
                    total_steps=self.total_steps,
                    step_with_optimizer=step_with_optimizer,
                    **self.config.sched,
                )
            if "ema" in self.config:
                ema_enabled = self.config.ema.pop("enabled", False)
                if ema_enabled:
                    ema.check()
                    self.ema = EMA(self.model, include_online_model=False, **self.config.ema)
                    self.ema.add_to_optimizer_post_step_hook(self.optimizer)
                self.config.ema.enabled = ema_enabled
        self.train_metrics = self.build_train_metrics()
        self.evaluate_metrics = self.build_evaluate_metrics()

    def __post_init__(self):
        if self.config.platform != "deepspeed":
            if "checkpoint" in self.config:
                self.load_checkpoint(self.config.checkpoint)
            if self.distributed:
                self.model = nn.parallel.DistributedDataParallel(
                    self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
                )
        super().__post_init__()
        if self.config.platform == "deepspeed" and "checkpoint" in self.config:
            self.load_checkpoint(self.config.checkpoint)
        self.yaml(os.path.join(self.dir, "trainer.yaml"))

    def train(self, train_splits: list[str] | None = None, evaluate_splits: list[str] | None = None) -> NestedDict:
        r"""
        Train the model on the specified data splits.
        """
        return super().train(train_splits, evaluate_splits)

    def train_step(self, data) -> Tuple[Any, torch.Tensor]:
        data = to_device(data, self.device)
        with self.autocast(), self.accumulate():
            pred = self.model(**data)
            self.advance(pred["loss"])
            self.metric_fn(pred, data)
        return pred, pred["loss"]

    def evaluate(self, evaluate_splits: list[str] | None = None) -> NestedDict:
        r"""
        Evaluate the model on the specified data splits.
        """
        return super().evaluate(evaluate_splits)

    def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
        data = to_device(data, self.device)
        model = self.ema or self.model
        output = model(**data)
        self.metric_fn(output, data)
        return output, output["loss"]

    @torch.inference_mode()
    def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
        r"""
        Infer predictions with the model on the specified data split.
        """

        self.mode = "inf"  # type: ignore
        loader = self.dataloaders[split]
        preds = FlatDict()
        labels = FlatDict()
        model = self.ema or self.model
        for _, data in tqdm(enumerate(loader), total=len(loader)):  # noqa: F402
            data = to_device(data, self.device)
            pred = model(**data)
            if isinstance(pred, tuple):
                pred, loss = pred
            for task, p in pred.items():
                preds[task].extend(p["logits"].squeeze(-1).tolist())
                if task in data:
                    labels[task].extend(data[task].squeeze(-1).tolist())

        if self.distributed:
            torch.cuda.synchronize()
            for task in preds.keys():
                preds[task] = self.gather_for_metrics(preds[task])
            for task in labels.keys():
                labels[task] = self.gather_for_metrics(labels[task])
        if labels:
            if len(preds) == 1:
                return FlatDict(predict=next(iter(preds.values())), label=next(iter(labels.values())))
            return NestedDict({task: {"predict": preds[task], "label": labels[task]} for task in preds})
        if len(preds) == 1:
            return next(iter(preds.values()))
        return preds

    def metric_fn(self, pred, data):
        self.metrics.update(pred["logits"], data["labels"])
        self.meters.update({"loss": pred["loss"]})

    @cached_property
    def task(self):
        if not self.datasets:
            raise ValueError("No datasets found")
        tasks = self.datasets.train.tasks if "train" in self.datasets else next(iter(self.datasets.values())).tasks
        if len(tasks) != 1:
            raise ValueError(f"Expected exactly one task, got {len(tasks)}")
        return next(iter(tasks.values()))

    @cached_property
    def network(self):
        head = HeadConfig(num_labels=self.task.num_labels, problem_type=self.task.type, type=self.task.level)
        if "head" not in self.config.network:
            self.config.network.head = NestedDict(head)
        else:
            self.config.network.head.merge(head, overwrite=False)
        return self.config.network

    def build_datasets(self) -> NestedDict[str, Dataset]:
        return self._build_dataset(self.config.data)

    def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
        root = config.pop("root", None)
        if root is None:
            raise ValueError(f"Unable to build dataset for {config}, root is not specified.")
        local_root = Path(root).expanduser().resolve()
        if name is None:
            name = "/".join(local_root.parts[-2:])
        ratio = config.pop("ratio", None)
        try:
            is_hf_dataset = bool(get_dataset_split_names(root))
        except FileNotFoundError:
            is_hf_dataset = False
        is_local_dataset = local_root.is_dir()
        dataset = None
        if is_local_dataset:
            dataset = self._build_local_dataset(config, str(local_root), ratio, name)
        elif is_hf_dataset:
            dataset = self._build_hf_dataset(config, root, ratio, name)
        else:
            raise ValueError(
                f"Dataset root '{root}' is invalid. It must be either:\n"
                f"  - A valid Hugging Face dataset repository ID\n"
                f"  - A path to an existing local directory"
            )
        if not dataset:
            raise ValueError(f"No datasets built. This is likely due to missing data paths in {config}.")
        config.root = root
        config.ratio = ratio
        return dataset

    def _build_local_dataset(self, config: NestedDict, root: str, ratio: float | None, name: str) -> NestedDict:
        r"""Build dataset from local directory."""
        train_splits = []
        other_splits = []
        splits = [k for k in defaults.DATASET_SPLITS if config.get(k) is not None]
        if splits:
            train_splits = [key for key in splits if key.startswith(defaults.TRAIN_SPLITS)]
            other_splits = [key for key in splits if key not in train_splits]
        # Automatically find splits for a local 🤗 dataset
        if not splits:
            splits = get_dataset_split_names(root)
            for split, data_files in DataFilesDict.from_local_or_remote(get_data_patterns(root), root).items():
                split = str(split)
                if len(data_files) > 1:
                    for idx, data_file in enumerate(data_files):
                        config[f"{split}-{str(idx).zfill(5)}-of-{str(len(data_files)).zfill(5)}"] = data_file
                else:
                    config[split] = data_files[0]
                if split in defaults.TRAIN_SPLITS:
                    train_splits.append(split)
                else:
                    other_splits.append(split)
        if not splits:
            raise ValueError(f"No splits found for dataset {name}. Please specify at least one split in the config.")
        print(f"Building local dataset {name}")
        dataset = NestedDict()
        ignored_keys = train_splits + other_splits
        dataset_factory = partial(
            DATASETS.build,
            tokenizer=self.tokenizer,
            auto_rename_label_col=True,
            **{k: v for k, v in config.items() if k not in ignored_keys},
        )
        if self.config.training:
            for split in train_splits:
                dataset[split] = dataset_factory(
                    os.path.join(root, config[split]), split=split, train=True, ratio=ratio
                )
        elif train_splits:
            warn("Training is disabled, ignoring training splits", RuntimeWarning, stacklevel=2)
        for split in other_splits:
            dataset[split] = dataset_factory(os.path.join(root, config[split]), split=split, train=False)
        return dataset

    def _build_hf_dataset(self, config: NestedDict, root: str, ratio: float | None, name: str) -> NestedDict:
        r"""Build dataset from HuggingFace datasets."""
        splits = [k for k in defaults.DATASET_SPLITS if config.get(k) is not None] or get_dataset_split_names(root)
        train_splits = [key for key in splits if key.startswith(defaults.TRAIN_SPLITS)]
        other_splits = [key for key in splits if key not in train_splits]
        print(f"Building HuggingFace dataset {name}")
        dataset = NestedDict()
        ignored_keys = train_splits + other_splits
        dataset_factory = partial(
            DATASETS.build,
            tokenizer=self.tokenizer,
            **{k: v for k, v in config.items() if k not in ignored_keys},
        )
        if self.config.training:
            for split in train_splits:
                dataset[split] = dataset_factory(root, split=split, train=True, ratio=ratio)
        elif train_splits:
            warn("Training is disabled, ignoring training splits", RuntimeWarning, stacklevel=2)
        for split in other_splits:
            dataset[split] = dataset_factory(root, split=split, train=False)
        return dataset

    def build_dataloaders(self) -> NestedDict[str, DataLoader]:
        dataloaders = NestedDict()
        datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
        default_kwargs = self.config.setdefault("dataloader", DataloaderConfig())
        dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
        for k, d in datasets.items():
            dataloader_kwargs_ = dataloader_kwargs.setdefault(k, default_kwargs)
            dataloader_kwargs_.merge(default_kwargs, overwrite=False)
            dataloaders[k] = DataLoader(
                d, distributed=self.distributed, collate_fn=self.collate_fn, **dataloader_kwargs_
            )
        return dataloaders

    def build_train_metrics(self) -> MetricMeters:
        # Return a MetricMeters object for the training set: we do not need very
        # precise metrics during training, and MetricMeters is more efficient
        return MetricMeters(METRICS.build(type=self.task.type, num_labels=self.task.num_labels))

    def build_evaluate_metrics(self) -> Metrics:
        return METRICS.build(type=self.task.type, num_labels=self.task.num_labels)

    @staticmethod
    def collate_fn(batch) -> dict:
        return batch
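
A minimal usage sketch, not part of the source above: it assumes Config and Runner are importable from multimolecule.runner, and the config keys (pretrained, data.root, training, optim) are inferred from the constructor shown here. The repository ids are placeholders, and the real Config schema may expose more options.

Python
from multimolecule.runner import Config, Runner

# Keys below are inferred from Runner.__init__ above; ids are illustrative.
config = Config()
config.pretrained = "multimolecule/rnafm"  # 🤗 repo id or a local directory
config.data.root = "multimolecule/bprna"   # 🤗 dataset repo id or a local directory
config.training = True
config.optim.lr = 1e-4  # forwarded to OPTIMIZERS.build via config.optim

runner = Runner(config)
results = runner.train()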

train

Python
train(
    train_splits: list[str] | None = None,
    evaluate_splits: list[str] | None = None,
) -> NestedDict

Train the model on the specified data splits.

Source code in multimolecule/runner/runner.py
Python
def train(self, train_splits: list[str] | None = None, evaluate_splits: list[str] | None = None) -> NestedDict:
    r"""
    Train the model on the specified data splits.
    """
    return super().train(train_splits, evaluate_splits)
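
A hedged example call; the split names below are assumptions and must match keys in runner.datasets:

Python
# Train on the "train" split, evaluating on "validation" along the way.
# Passing None delegates split selection to the base Runner's defaults.
results = runner.train(train_splits=["train"], evaluate_splits=["validation"])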

evaluate

Python
evaluate(
    evaluate_splits: list[str] | None = None,
) -> NestedDict

Evaluate the model on the specified data splits.

Source code in multimolecule/runner/runner.py
Python
def evaluate(self, evaluate_splits: list[str] | None = None) -> NestedDict:
    r"""
    Evaluate the model on the specified data splits.
    """
    return super().evaluate(evaluate_splits)
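
For example, assuming a "test" split is configured:

Python
# Evaluate on the "test" split; evaluate_step uses the EMA weights when EMA is enabled.
metrics = runner.evaluate(evaluate_splits=["test"])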

infer

Python
infer(split: str = 'inf') -> NestedDict | FlatDict | list

Infer predictions with the model on the specified data split.

Source code in multimolecule/runner/runner.py
Python
@torch.inference_mode()
def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
    r"""
    Infer predictions with the model on the specified data split.
    """

    self.mode = "inf"  # type: ignore
    loader = self.dataloaders[split]
    preds = FlatDict()
    labels = FlatDict()
    model = self.ema or self.model
    for _, data in tqdm(enumerate(loader), total=len(loader)):  # noqa: F402
        data = to_device(data, self.device)
        pred = model(**data)
        if isinstance(pred, tuple):
            pred, loss = pred
        for task, p in pred.items():
            preds[task].extend(p["logits"].squeeze(-1).tolist())
            if task in data:
                labels[task].extend(data[task].squeeze(-1).tolist())

    if self.distributed:
        torch.cuda.synchronize()
        for task in preds.keys():
            preds[task] = self.gather_for_metrics(preds[task])
        for task in labels.keys():
            labels[task] = self.gather_for_metrics(labels[task])
    if labels:
        if len(preds) == 1:
            return FlatDict(predict=next(iter(preds.values())), label=next(iter(labels.values())))
        return NestedDict({task: {"predict": preds[task], "label": labels[task]} for task in preds})
    if len(preds) == 1:
        return next(iter(preds.values()))
    return preds
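
The return shape follows the branches above; a sketch, assuming an "inf" split is configured:

Python
output = runner.infer("inf")
# Single task with labels    -> FlatDict(predict=[...], label=[...])
# Multiple tasks with labels -> NestedDict({task: {"predict": [...], "label": [...]}})
# Single task without labels -> a plain list of predictions
# Multiple tasks, no labels  -> FlatDict mapping each task to its predictions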