Skip to content

Commit

Permalink
Merge pull request #84 from kirpoly/dev
Browse files Browse the repository at this point in the history
Fix `last_checkpoint`
  • Loading branch information
arseniybelkov authored Mar 22, 2024
2 parents f5a8e0b + 069ab05 commit bbf8bbb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
8 changes: 8 additions & 0 deletions tests/torch/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def _create_file(fp: Path):

assert last_checkpoint(temp_dir / "two_checkpoints") == temp_dir / "two_checkpoints" / "second" / "model1.ckpt"

symlink_path = Path(temp_dir / "two_checkpoints" / "first" / "symlink")
symlink_path.symlink_to(temp_dir / "two_checkpoints" / "second" / "model1.ckpt")

assert (
last_checkpoint(temp_dir / "two_checkpoints" / "first")
== temp_dir / "two_checkpoints" / "second" / "model1.ckpt"
)

_create_file(temp_dir / "zero_checkpoints" / "first" / "not_ckpt")
time.sleep(0.1)
_create_file(temp_dir / "zero_checkpoints" / "second" / "not_ckpt")
Expand Down
20 changes: 13 additions & 7 deletions thunder/torch/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Any, Union
from typing import Any, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -84,26 +84,32 @@ def maybe_from_np(*x: Any, device: Union[torch.device, str] = "cpu") -> Any:
>>> x, y, z = maybe_from_np(x, y, z) # maybe_from_np converts np arrays and tensors and does not affect other types
>>> dict_of_tensors = to_np(dict_of_np) # maybe_from_np converts any collection
"""

def to_tensor(x):
if isinstance(x, torch.Tensor):
return x.to(device)
return torch.from_numpy(x).to(device)

return squeeze_first(apply_to_collection(x, (np.ndarray, np.generic, torch.Tensor), to_tensor))


def last_checkpoint(root: Union[Path, str]) -> Union[Path, str]:
def last_checkpoint(root: Union[Path, str]) -> Optional[Union[Path, str]]:
"""
Load most fresh last.ckpt file based on time.
Parameters
----------
root: Union[Path, str]
Path to folder, where last.ckpt supposed to be.
Path to folder, where last.ckpt or its symbolic link supposed to be.
Returns
-------
checkpoint_path: Union[Path, str]
If last.ckpt exists - returns Path to it. Otherwise, returns 'last'.
"""
checkpoints = [p for p in Path(root).glob("**/*.ckpt") if p.name != "last.ckpt"]
if not checkpoints:
return "last"
return max(checkpoints, key=lambda t: os.stat(t).st_mtime)
checkpoints = []
for p in Path(root).rglob("*"):
if p.is_symlink():
p = p.resolve(strict=False)
if p.suffix == ".ckpt":
checkpoints.append(p)

return max(checkpoints, key=lambda t: os.stat(t).st_mtime, default="last")

0 comments on commit bbf8bbb

Please sign in to comment.