def _format_metrics(metrics: "NestedDict") -> "NestedDict":
    """Render every metric in *metrics* as a string with 8 decimal places.

    Values that are lists are reduced with ``mean`` before formatting.
    """
    return NestedDict(
        {
            key: format(mean(value) if isinstance(value, list) else value, ".8f")
            for key, value in metrics.all_items()
        }
    )


def _remove_empty_dirs(top: str) -> None:
    """Delete every directory under *top* (and *top* itself) that is empty.

    The walk is bottom-up so a parent is inspected only after its children
    have been handled; ``os.listdir`` re-checks the directory on disk because
    the ``dirs`` list produced by ``os.walk`` was captured before any child
    was removed.
    """
    for root, _, _ in os.walk(top, topdown=False):
        if not os.listdir(root):
            os.rmdir(root)


def get_result_stat(experiment_root: str, remove_empty: bool = True) -> list[Result]:
    """Collect the best result of every run found under *experiment_root*.

    A directory is treated as a run when it contains ``run.log``. A run is
    usable only when it also contains a ``best.json`` that has an ``index``
    entry; other runs are skipped. For each usable run, ``trainer.yaml`` in
    the same directory is read for ``pretrained`` (only the last ``/``
    component is kept) and ``seed``.

    WARNING: with ``remove_empty=True`` (the default) this function deletes
    data: the directory of every unusable run is removed recursively, and
    afterwards every empty directory under *experiment_root* (including
    *experiment_root* itself, if it ends up empty) is removed.

    Args:
        experiment_root: Directory tree to scan.
        remove_empty: Whether to delete unusable runs and empty directories.

    Returns:
        The collected results, sorted by ``(pretrained, seed, id)``.
    """
    results = []
    for root, dirs, files in tqdm(os.walk(experiment_root)):
        if "run.log" not in files:
            continue

        # Load best.json only when present; a run without it, or whose
        # best.json has no "index", is unusable.
        best = None
        if "best.json" in files:
            best = NestedDict.from_json(os.path.join(root, "best.json"))
        if best is None or "index" not in best:
            if remove_empty:
                shutil.rmtree(root)
                # The subtree is gone: stop os.walk from trying to descend
                # into directories that no longer exist.
                dirs[:] = []
            continue

        config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
        pretrained = config.pretrained.split("/")[-1]
        result = Result(id=best.id, pretrained=pretrained, seed=config.seed)

        # "val" is accepted as a fallback name for the validation split.
        validation = best.get("validation", best.get("val", NestedDict()))
        test = best.get("test", NestedDict())
        result.validation = _format_metrics(validation)
        result.test = _format_metrics(test)
        result.epoch = best.index

        # Drop time / loss / lr entries from the summary.
        # NOTE(review): relies on Result.pop accepting dotted keys - confirm.
        for key in ("validation.time", "test.time", "validation.loss", "test.loss", "validation.lr", "test.lr"):
            result.pop(key, None)
        results.append(result)

    if remove_empty:
        _remove_empty_dirs(experiment_root)

    results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
    return results