def last_checkpoint(root: Union[Path, str]) -> Optional[Union[Path, str]]:
    """
    Find the most recently modified ``*.ckpt`` file under ``root``.

    The directory tree is walked recursively. Symbolic links are resolved to
    their targets before the ``.ckpt`` suffix is checked, so a link with any
    name that points to a checkpoint counts as that checkpoint, and the
    returned path may therefore lie outside ``root``.

    Parameters
    ----------
    root: Union[Path, str]
        Path to folder, where checkpoints or symbolic links to them
        are supposed to be.

    Returns
    -------
    checkpoint_path: Union[Path, str]
        Path to the existing checkpoint file with the latest modification
        time. If no checkpoint file is found, returns the string 'last'.
        ``None`` is never returned; the ``Optional`` in the annotation is
        kept only for interface compatibility.
    """
    checkpoints = []
    for p in Path(root).rglob("*"):
        if p.is_symlink():
            try:
                p = p.resolve(strict=False)
            except (OSError, RuntimeError):
                # Symlink loops raise RuntimeError before Python 3.13;
                # an unresolvable link cannot name a usable checkpoint.
                continue
        # is_file() drops dangling link targets and directories that merely
        # end in ".ckpt"; both would otherwise break or mislead the mtime
        # comparison below.
        if p.suffix == ".ckpt" and p.is_file():
            checkpoints.append(p)

    return max(checkpoints, key=lambda t: t.stat().st_mtime, default="last")