From facc10b9bcb94d92489e07a48f24f24dd816073f Mon Sep 17 00:00:00 2001 From: Kirill Poliakov Date: Mon, 26 Feb 2024 11:30:54 +0300 Subject: [PATCH 1/8] Fix last_checkpoint() --- thunder/torch/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/thunder/torch/utils.py b/thunder/torch/utils.py index a518dac..ef9f6a5 100644 --- a/thunder/torch/utils.py +++ b/thunder/torch/utils.py @@ -97,13 +97,21 @@ def last_checkpoint(root: Union[Path, str]) -> Union[Path, str]: Parameters ---------- root: Union[Path, str] - Path to folder, where last.ckpt supposed to be. + Path to folder, where last.ckpt or its symbolic link is supposed to be. Returns ------- checkpoint_path: Union[Path, str] If last.ckpt exists - returns Path to it. Otherwise, returns 'last'. """ - checkpoints = [p for p in Path(root).glob("**/*.ckpt") if p.name != "last.ckpt"] + checkpoints = [] + for p in Path(root).rglob('*'): + if p.is_symlink(): + p = p.resolve(strict=False) + if p.suffix == '.ckpt': + checkpoints.append(p) + elif p.suffix == '.ckpt': + checkpoints.append(p) + if not checkpoints: - return "last" + return None return max(checkpoints, key=lambda t: os.stat(t).st_mtime) From ac7e59f469c472b53be25b573e9605c40c85a457 Mon Sep 17 00:00:00 2001 From: Kirill Poliakov Date: Mon, 26 Feb 2024 11:35:41 +0300 Subject: [PATCH 2/8] Update type annotations for last_checkpoint --- thunder/torch/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/thunder/torch/utils.py b/thunder/torch/utils.py index ef9f6a5..5a25b18 100644 --- a/thunder/torch/utils.py +++ b/thunder/torch/utils.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Union +from typing import Any, Union, Optional import numpy as np import torch @@ -84,14 +84,16 @@ def maybe_from_np(*x: Any, device: Union[torch.device, str] = "cpu") -> Any: >>> x, y, z = maybe_from_np(x, y, z) # maybe_from_np converts np arrays and tensors and does not affect other types >>> 
dict_of_tensors = to_np(dict_of_np) # maybe_from_np converts any collection """ + def to_tensor(x): if isinstance(x, torch.Tensor): return x.to(device) return torch.from_numpy(x).to(device) + return squeeze_first(apply_to_collection(x, (np.ndarray, np.generic, torch.Tensor), to_tensor)) -def last_checkpoint(root: Union[Path, str]) -> Union[Path, str]: +def last_checkpoint(root: Union[Path, str]) -> Optional[Union[Path, str]]: """ Load most fresh last.ckpt file based on time. Parameters From 3a94c66e4a8bc2eab3c31d3b4c676bf28834b174 Mon Sep 17 00:00:00 2001 From: Kirill Poliakov Date: Mon, 26 Feb 2024 12:54:19 +0300 Subject: [PATCH 3/8] Fix imports --- thunder/torch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/torch/utils.py b/thunder/torch/utils.py index 5a25b18..2055cb6 100644 --- a/thunder/torch/utils.py +++ b/thunder/torch/utils.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Union, Optional +from typing import Any, Optional, Union import numpy as np import torch From be3cd636c483ccb30ef8a6d96611f6ae5ade258a Mon Sep 17 00:00:00 2001 From: Kirill Poliakov Date: Mon, 26 Feb 2024 13:03:51 +0300 Subject: [PATCH 4/8] Update test_last_checkpoint --- tests/torch/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torch/test_utils.py b/tests/torch/test_utils.py index ea46482..f115f4e 100644 --- a/tests/torch/test_utils.py +++ b/tests/torch/test_utils.py @@ -94,4 +94,4 @@ def _create_file(fp: Path): time.sleep(0.1) _create_file(temp_dir / "zero_checkpoints" / "second" / "not_ckpt") - assert last_checkpoint(temp_dir / "zero_checkpoints") == "last" + assert last_checkpoint(temp_dir / "zero_checkpoints") == None From a758ab7a32d73fc55cc9e4ca13832958b83812ba Mon Sep 17 00:00:00 2001 From: Kirill Poliakov Date: Mon, 26 Feb 2024 13:05:42 +0300 Subject: [PATCH 5/8] Fix dumb mistake --- tests/torch/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/torch/test_utils.py b/tests/torch/test_utils.py index f115f4e..80180aa 100644 --- a/tests/torch/test_utils.py +++ b/tests/torch/test_utils.py @@ -94,4 +94,4 @@ def _create_file(fp: Path): time.sleep(0.1) _create_file(temp_dir / "zero_checkpoints" / "second" / "not_ckpt") - assert last_checkpoint(temp_dir / "zero_checkpoints") == None + assert last_checkpoint(temp_dir / "zero_checkpoints") is None From bc1d21bc22dbc44be2c59256dbee74c990365fef Mon Sep 17 00:00:00 2001 From: kirillp Date: Thu, 21 Mar 2024 20:36:41 +0000 Subject: [PATCH 6/8] Fixes according to reviews --- thunder/torch/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/thunder/torch/utils.py b/thunder/torch/utils.py index 2055cb6..5825c56 100644 --- a/thunder/torch/utils.py +++ b/thunder/torch/utils.py @@ -106,14 +106,12 @@ def last_checkpoint(root: Union[Path, str]) -> Optional[Union[Path, str]]: If last.ckpt exists - returns Path to it. Otherwise, returns 'last'. """ checkpoints = [] - for p in Path(root).rglob('*'): + for p in Path(root).rglob("*"): if p.is_symlink(): p = p.resolve(strict=False) - if p.suffix == '.ckpt': - checkpoints.append(p) - elif p.suffix == '.ckpt': + if p.suffix == ".ckpt": checkpoints.append(p) if not checkpoints: - return None + return "last" return max(checkpoints, key=lambda t: os.stat(t).st_mtime) From 7560282975f70068234e4e23d1880309c394f383 Mon Sep 17 00:00:00 2001 From: kirillp Date: Thu, 21 Mar 2024 21:04:39 +0000 Subject: [PATCH 7/8] Update test_last_checkpoint --- tests/torch/test_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/torch/test_utils.py b/tests/torch/test_utils.py index 80180aa..114eb18 100644 --- a/tests/torch/test_utils.py +++ b/tests/torch/test_utils.py @@ -90,8 +90,16 @@ def _create_file(fp: Path): assert last_checkpoint(temp_dir / "two_checkpoints") == temp_dir / "two_checkpoints" / "second" / "model1.ckpt" + symlink_path = Path(temp_dir / "two_checkpoints" / "first" 
/ "symlink") + symlink_path.symlink_to(temp_dir / "two_checkpoints" / "second" / "model1.ckpt") + + assert ( + last_checkpoint(temp_dir / "two_checkpoints" / "first") + == temp_dir / "two_checkpoints" / "second" / "model1.ckpt" + ) + _create_file(temp_dir / "zero_checkpoints" / "first" / "not_ckpt") time.sleep(0.1) _create_file(temp_dir / "zero_checkpoints" / "second" / "not_ckpt") - assert last_checkpoint(temp_dir / "zero_checkpoints") is None + assert last_checkpoint(temp_dir / "zero_checkpoints") == "last" From 9f7e1a1b74eb32277c6d4d5ddca2e227f8a490ed Mon Sep 17 00:00:00 2001 From: kirillp Date: Fri, 22 Mar 2024 09:59:51 +0000 Subject: [PATCH 8/8] One more fix to last_checkpoint --- thunder/torch/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/thunder/torch/utils.py b/thunder/torch/utils.py index 5825c56..6049401 100644 --- a/thunder/torch/utils.py +++ b/thunder/torch/utils.py @@ -112,6 +112,4 @@ def last_checkpoint(root: Union[Path, str]) -> Optional[Union[Path, str]]: if p.suffix == ".ckpt": checkpoints.append(p) - if not checkpoints: - return "last" - return max(checkpoints, key=lambda t: os.stat(t).st_mtime) + return max(checkpoints, key=lambda t: os.stat(t).st_mtime, default="last")