Skip to content

Stat

multimolecule.api.stat.Result

Bases: NestedDict

Source code in multimolecule/api/stat.py
Python
class Result(NestedDict):
    """Summary of a single training run's best checkpoint, as assembled by ``get_result_stat``."""

    pretrained: str  # last path component of the run config's ``pretrained`` value
    id: str  # the ``id`` entry read from the run's ``best.json``
    seed: int  # the ``seed`` entry read from the run's ``trainer.yaml``
    epoch: int  # the ``index`` entry of ``best.json`` (presumably the best epoch's index — confirm)
    validation: NestedDict  # flattened metric name -> value rendered as a string with 8 decimals
    test: NestedDict  # same layout as ``validation``, for the test split

multimolecule.api.stat.get_result_stat

Python
get_result_stat(
    experiment_root: str, remove_empty: bool = True
) -> list[Result]
Source code in multimolecule/api/stat.py
Python
def get_result_stat(experiment_root: str, remove_empty: bool = True) -> list[Result]:
    """Collect the best-checkpoint metrics of every run found under ``experiment_root``.

    A directory is treated as a run when it contains ``run.log``. For each run, ``best.json`` supplies
    the best checkpoint's id, index and metrics, and ``trainer.yaml`` supplies the run configuration.

    Args:
        experiment_root: Directory walked recursively for runs.
        remove_empty: If ``True`` (destructive), delete run directories that lack ``best.json`` or whose
            ``best.json`` has no ``index`` entry, then delete every directory left empty afterwards
            (``experiment_root`` itself included, if it ends up empty).

    Returns:
        One ``Result`` per completed run, sorted by ``(pretrained, seed, id)``.
    """
    results = []
    for root, dirs, files in tqdm(os.walk(experiment_root)):
        if "run.log" not in files:
            continue
        best = NestedDict.from_json(os.path.join(root, "best.json")) if "best.json" in files else None
        # A run without a best checkpoint (no best.json, or no index in it) is considered incomplete.
        if best is None or "index" not in best:
            if remove_empty:
                shutil.rmtree(root)
                # The directory no longer exists; stop os.walk from descending into its subdirectories.
                dirs.clear()
            continue
        config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
        # rstrip guards against a trailing slash producing an empty name.
        pretrained = config.pretrained.rstrip("/").split("/")[-1]
        result = Result(id=best.id, pretrained=pretrained, seed=config.seed)
        # Older runs may have stored validation metrics under "val".
        result.validation = _format_metrics(best.get("validation", best.get("val", NestedDict())))
        result.test = _format_metrics(best.get("test", NestedDict()))
        result.epoch = best.index
        # Bookkeeping values that are not useful in the summary table.
        for key in ("validation.time", "test.time", "validation.loss", "test.loss", "validation.lr", "test.lr"):
            result.pop(key, None)
        results.append(result)
    if remove_empty:
        _remove_empty_dirs(experiment_root)
    results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
    return results


def _format_metrics(metrics: NestedDict) -> NestedDict:
    """Flatten ``metrics`` and render each value with 8 decimals, averaging list values first."""
    return NestedDict(
        {key: format(mean(value) if isinstance(value, list) else value, ".8f") for key, value in metrics.all_items()}
    )


def _remove_empty_dirs(top: str) -> None:
    """Delete every empty directory under ``top`` (including ``top``), deepest first."""
    # Bottom-up traversal visits children before their parent, so a parent that became empty
    # after its children were removed is itself removed, regardless of nesting depth.
    # os.listdir is re-queried because the dirs list os.walk yields predates those removals.
    for root, _, _ in os.walk(top, topdown=False):
        if not os.listdir(root):
            os.rmdir(root)

multimolecule.api.stat.write_result_stat

Python
write_result_stat(results: list[Result], path: str) -> None
Source code in multimolecule/api/stat.py
Python
def write_result_stat(results: list[Result], path: str) -> None:
    """Write ``results`` to ``path`` as a CSV file, one row per result.

    Each result is flattened with ``all_items()`` so every leaf value gets its own column. An empty
    ``comment`` column is inserted just before the last column for manual annotation, and missing
    values are written as empty cells. An empty ``results`` list produces a header-only file.

    Args:
        results: Results to write, e.g. as returned by ``get_result_stat``.
        path: Destination CSV path; overwritten if it exists.
    """
    rows = [dict(result.all_items()) for result in results]
    df = pd.DataFrame(rows)
    # Clamp to 0: with no results there are no columns, and loc=-1 would make DataFrame.insert raise.
    df.insert(max(len(df.columns) - 1, 0), "comment", "")
    # fillna is not in-place; the result must be reassigned to take effect.
    df = df.fillna("")
    df.to_csv(path, index=False)

multimolecule.api.stat.Config

Bases: Config

Source code in multimolecule/api/stat.py
Python
class Config(chanfig.Config):
    """Options for the result-statistics step (assumed consumer: get_result_stat / write_result_stat — confirm)."""

    experiment_root: str = "experiments"  # root directory searched for runs; presumably passed to get_result_stat
    out_path: str = "result.csv"  # output CSV path; presumably passed to write_result_stat as ``path``
    remove_empty: bool = True  # presumably forwarded to get_result_stat, where True deletes incomplete runs/empty dirs