Skip to content

Commit e51df64

Browse files
authored
Merge pull request #6 from neptune-ai/handle-torchviz
handle if dot is not present
2 parents 4e69f16 + 01e1876 commit e51df64

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
## neptune-pytorch 0.2.0
2+
3+
### Fixes
4+
- Change where `checkpoints` are logged. Previously they were logged under `base_namespace/model`, but now they are logged under `base_namespace/model/checkpoints` (https://github.com/neptune-ai/neptune-pytorch/pull/5)
5+
- Add warning if `dot` is not installed instead of hard error. Also, improve clean-up of visualization files (https://github.com/neptune-ai/neptune-pytorch/pull/6)
6+
7+
18
## neptune-pytorch 0.1.0 (initial release)
29

310
### Features

src/neptune_pytorch/impl/__init__.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
__all__ = ["__version__", "NeptuneLogger"]
1717

1818
import os
19+
import uuid
1920
import warnings
2021
import weakref
2122
from typing import (
@@ -42,6 +43,7 @@
4243
IS_TORCHVIZ_AVAILABLE = True
4344
try:
4445
import torchviz
46+
from graphviz import ExecutableNotFound
4547
except ImportError:
4648
IS_TORCHVIZ_AVAILABLE = False
4749

@@ -152,10 +154,18 @@ def add_visualization_hook(self):
152154
def hook(module, input, output):
153155
if not self._is_viz_saved:
154156
dot = torchviz.make_dot(output, params=dict(module.named_parameters()))
155-
# Use tempfile correctly.
156157
dot.format = "png"
157-
dot.render(outfile="torch-viz.png")
158-
self._namespace_handler["model"]["visualization"].upload("torch-viz.png")
158+
# generate unique name so that multiple concurrent runs
159+
# don't over-write each other.
160+
viz_name = str(uuid.uuid4()) + ".png"
161+
try:
162+
dot.render(outfile=viz_name)
163+
safe_upload_visualization(self._namespace_handler["model"], "visualization", viz_name)
164+
except ExecutableNotFound:
165+
# This errors because `dot` renderer is not found even
166+
# if the Python bindings of `graphviz` are available.
167+
warnings.warn("Skipping model visualization because no dot (graphviz) installation was found.")
168+
159169
self._is_viz_saved = True
160170

161171
self._vis_hook_handler = self.model.register_forward_hook(hook)
@@ -181,7 +191,7 @@ def save_model(self, model_name: Optional[str] = None):
181191
# User is not expected to add extension
182192
model_name = model_name + ".pt"
183193

184-
safe_upload(self._namespace_handler["model"], model_name, self.model)
194+
safe_upload_model(self._namespace_handler["model"], model_name, self.model)
185195

186196
def save_checkpoint(self, checkpoint_name: Optional[str] = None):
187197
if checkpoint_name is None:
@@ -192,7 +202,7 @@ def save_checkpoint(self, checkpoint_name: Optional[str] = None):
192202
# User is not expected to add extension
193203
checkpoint_name = checkpoint_name + ".pt"
194204

195-
safe_upload(self._namespace_handler["model"]["checkpoints"], checkpoint_name, self.model)
205+
safe_upload_model(self._namespace_handler["model"]["checkpoints"], checkpoint_name, self.model)
196206

197207
def __del__(self):
198208
# Remove hooks
@@ -207,7 +217,23 @@ def __del__(self):
207217
self._vis_hook_handler.remove()
208218

209219

210-
def safe_upload(run, name, model):
220+
def safe_upload_visualization(run: Run, name: str, file_name: str):
    """Upload a rendered visualization image and clean up its files on disk.

    The image at ``file_name`` is opened as a binary stream and uploaded to
    ``run[name]`` with a ``png`` extension. A ``weakref.finalize`` callback is
    registered on the stream object so that the image, and the graphviz
    intermediate ``.gv`` file that sits next to it, are deleted once the
    stream object goes out of scope.

    Args:
        run: Namespace/run handler the image is uploaded to.
        name: Field name under ``run`` that receives the upload.
        file_name: Path of the rendered image file on disk.
    """

    def remove(file_name):
        # Swap only the final extension for ".gv". A plain str.replace would
        # rewrite every ".png" occurrence in the path (e.g. in a directory
        # name) and could point at the wrong file.
        source_name = os.path.splitext(file_name)[0] + ".gv"
        for path in (file_name, source_name):
            # Remove each file independently so that one missing file does
            # not prevent the other from being cleaned up.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    with open(file_name, "rb") as f:
        # Defer deletion until the stream object is collected, so the files
        # are still present while the upload reads from the stream.
        weakref.finalize(f, remove, file_name)
        run[name].upload(File.from_stream(f, extension="png"))
234+
235+
236+
def safe_upload_model(run: Run, name: str, model: torch.nn.Module):
211237
# Function to safely upload a file and
212238
# delete the file on completion of upload.
213239
# We utilise the weakref.finalize to remove

0 commit comments

Comments
 (0)