Commit 5ed8241

v3.0.0 (feat): Address review comments

1 parent 2894c99 · commit 5ed8241

4 files changed: +39 −30 lines

README.md

Lines changed: 28 additions & 12 deletions

@@ -6,8 +6,8 @@
 [![neptune_scale](https://img.shields.io/badge/neptune__scale-0.14.0+-orange.svg)](https://pypi.org/project/neptune-scale/)
 [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
 [![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)](https://opensource.org/licenses/Apache-2.0)
-</div>
 
+</div>
 
 The **Neptune-PyTorch** integration simplifies tracking your PyTorch experiments with Neptune by providing automated tracking of PyTorch model internals, including activations, gradients, and parameters.
 
@@ -87,6 +87,7 @@ for epoch in range(num_epochs):
 ```
 
 **Logged data in Neptune:**
+
 - **Model architecture**: Visual diagram and summary of the neural network
 - **Training metrics**: Loss curves and epoch progress
 - **Layer activations**: Mean, std, norm, histograms for each layer
@@ -176,6 +177,7 @@ for epoch in range(num_epochs):
 ```
 
 **Features demonstrated:**
+
 - **Layer filtering**: Only track Conv2d and Linear layers (reduces overhead)
 - **Custom statistics**: Use mean, std, hist instead of all 8 statistics
 - **Phase-specific tracking**: Different tracking strategies for train/validation
@@ -184,23 +186,27 @@ for epoch in range(num_epochs):
 ## Features
 
 ### Model monitoring
+
 - **Layer activations**: Track activation patterns across all layers with 8 different statistics
 - **Gradient analysis**: Monitor gradient flow and detect vanishing/exploding gradients
 - **Parameter tracking**: Log parameter statistics and distributions for model analysis
 - **Custom statistics**: Choose from mean, std, norm, min, max, var, abs_mean, and hist
 
 ### Configuration options
+
 - **Layer filtering**: Track only specific layer types (Conv2d, Linear, etc.)
 - **Phase organization**: Separate tracking for training/validation phases with custom prefixes
 - **Custom namespaces**: Organize experiments with custom folder structures
 
 ### Visualizations
+
 - **Model architecture**: Automatic model diagram generation with torchviz
 - **Distribution histograms**: 50-bin histograms for all tracked metrics
 - **Real-time monitoring**: Live tracking during training with Neptune
 - **Comparative analysis**: Easy comparison across experiments and runs
 
 ### Integration
+
 - **Minimal setup**: Simple integration with existing code
 - **PyTorch native**: Works with existing PyTorch workflows

@@ -263,20 +269,23 @@ The integration organizes all logged data under a clear hierarchical and customi
 **Example namespaces:**
 
 With `base_namespace="my_experiment"`:
+
 - `my_experiment/batch/loss` - Training loss
 - `my_experiment/model/summary` - Model architecture
 - `my_experiment/model/internals/activations/conv/1/mean` - Mean activation (no prefix)
 - `my_experiment/model/internals/train/activations/conv/1/mean` - Mean activation (with "train" prefix)
 - `my_experiment/model/internals/validation/gradients/linear1/norm` - L2 norm of gradients (with "validation" prefix)
 
 With `base_namespace=None`:
+
 - `batch/loss` - Training loss
 - `model/summary` - Model architecture
 - `model/internals/activations/conv/1/mean` - Mean activation (no prefix)
 - `model/internals/train/activations/conv/1/mean` - Mean activation (with "train" prefix)
 - `model/internals/validation/gradients/linear1/norm` - L2 norm of gradients (with "validation" prefix)
 
 **Layer name handling:**
+
 - Dots in layer names are automatically replaced with forward slashes for proper namespace organization
 - Example: `seq_model.0.weight` becomes `seq_model/0/weight` in the namespace
 - Example: `module.submodule.layer` becomes `module/submodule/layer` in the namespace
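The dot-to-slash rule above is simple enough to sketch directly. A minimal illustration (`sanitize_layer_name` is a hypothetical helper name, not the library's internal function):

```python
def sanitize_layer_name(name: str) -> str:
    """Replace dots with slashes so each dotted part becomes a namespace level."""
    return name.replace(".", "/")

print(sanitize_layer_name("seq_model.0.weight"))      # seq_model/0/weight
print(sanitize_layer_name("module.submodule.layer"))  # module/submodule/layer
```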
@@ -293,12 +302,13 @@ NeptuneLogger(
     model: torch.nn.Module,
     base_namespace: Optional[str] = None,
     track_layers: Optional[List[Type[nn.Module]]] = None,
-    tensor_stats: Optional[List[str]] = None,
+    tensor_stats: Optional[List[TensorStatType]] = None,
     log_model_diagram: bool = False
 )
 ```
 
 **Parameters:**
+
 - `run`: Neptune run object for logging
 - `model`: PyTorch model to track
 - `base_namespace`: Optional top-level folder for organization (default: `None`)
@@ -319,6 +329,7 @@ log_model_internals(
 ```
 
 **Parameters:**
+
 - `step`: Current training step for logging
 - `track_activations`: Track layer activations (default: `True`)
 - `track_gradients`: Track layer gradients (default: `True`)
@@ -327,16 +338,16 @@ log_model_internals(
 
 ### Available statistics
 
-| Statistic | Description | Use Case |
-|-----------|-------------|----------|
-| `mean` | Mean value | Monitor activation levels |
-| `std` | Standard deviation | Detect activation variance |
-| `norm` | L2 norm | Monitor gradient/activation magnitude |
-| `min` | Minimum value | Detect dead neurons |
-| `max` | Maximum value | Detect saturation |
-| `var` | Variance | Monitor activation spread |
-| `abs_mean` | Mean of absolute values | Monitor activation strength |
-| `hist` | 50-bin histogram | Visualize distributions |
+| Statistic  | Description             | Use Case                              |
+| ---------- | ----------------------- | ------------------------------------- |
+| `mean`     | Mean value              | Monitor activation levels             |
+| `std`      | Standard deviation      | Detect activation variance            |
+| `norm`     | L2 norm                 | Monitor gradient/activation magnitude |
+| `min`      | Minimum value           | Detect dead neurons                   |
+| `max`      | Maximum value           | Detect saturation                     |
+| `var`      | Variance                | Monitor activation spread             |
+| `abs_mean` | Mean of absolute values | Monitor activation strength           |
+| `hist`     | 50-bin histogram        | Visualize distributions               |
 
 ### Namespace structure
 
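As a rough plain-Python illustration of what each scalar statistic in the table measures (the library computes these with torch on real tensors; note that torch's `std`/`var` default to the unbiased estimator, while this sketch uses the population form for simplicity):

```python
import math

values = [-1.0, 0.0, 2.0, 3.0]  # stand-in for a flattened activation tensor
n = len(values)

mean = sum(values) / n                           # activation level
var = sum((v - mean) ** 2 for v in values) / n   # spread (population form)
std = math.sqrt(var)                             # square root of the variance
norm = math.sqrt(sum(v * v for v in values))     # L2 magnitude
abs_mean = sum(abs(v) for v in values) / n       # strength regardless of sign
lo, hi = min(values), max(values)                # dead neurons / saturation

print(mean, std, norm, abs_mean, lo, hi)
```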
@@ -360,16 +371,19 @@ log_model_internals(
 Contributions to neptune-pytorch are welcome. Here's how you can help:
 
 ### Report issues
+
 - Found a bug? [Open an issue](https://github.com/neptune-ai/neptune-pytorch/issues)
 - Include Python version, PyTorch version, and error traceback
 - Provide a minimal reproducible example
 
 ### Suggest features
+
 - Have an idea? [Create a feature request](https://github.com/neptune-ai/neptune-pytorch/issues)
 - Describe the use case and expected behavior
 - Check existing issues first to avoid duplicates
 
 ### Contribute code
+
 1. Fork the repository
 2. Create a feature branch: `git checkout -b feature/amazing-feature`
 3. Make your changes and add tests
@@ -381,11 +395,13 @@ Contributions to neptune-pytorch are welcome. Here's how you can help:
 ## Support
 
 ### Get help
+
 - 📖 **Documentation**: [Neptune PyTorch Docs](https://docs.neptune.ai/integrations/pytorch/)
 - 🔧 **Troubleshooting**: [Common Issues Guide](https://docs.neptune.ai/troubleshooting)
 - 🎫 **Support Portal**: [Reach out to us](https://support.neptune.ai)
 
 ### Resources
+
 - [Neptune Documentation](https://docs.neptune.ai/)
 - [PyTorch Documentation](https://pytorch.org/docs/)
 - [Neptune Examples](https://github.com/neptune-ai/scale-examples)

src/neptune_pytorch/impl/__init__.py

Lines changed: 2 additions & 3 deletions

@@ -20,7 +20,6 @@
 import warnings
 from typing import (
     List,
-    Literal,
     Optional,
     Type,
 )
@@ -30,7 +29,7 @@
 from neptune_scale import Run
 
 from neptune_pytorch.impl._torchwatcher import (
-    TENSOR_STATS,
+    TensorStatType,
     _TorchWatcher,
 )
 from neptune_pytorch.impl.version import __version__
@@ -121,7 +120,7 @@ def __init__(
         base_namespace: Optional[str] = None,
         log_model_diagram: bool = False,
         track_layers: Optional[List[Type[nn.Module]]] = None,
-        tensor_stats: Optional[List[Literal[tuple(TENSOR_STATS.keys())]]] = None,
+        tensor_stats: Optional[List[TensorStatType]] = None,
     ):
         if not isinstance(run, Run):
             raise ValueError("run must be a Neptune Run object")

src/neptune_pytorch/impl/_torchwatcher.py

Lines changed: 3 additions & 9 deletions

@@ -27,6 +27,8 @@
     "abs_mean": lambda x: x.abs().mean().item(),
     "hist": lambda x: torch.histogram(x, bins=50),
 }
+# Create a proper type for tensor statistics
+TensorStatType = Literal["mean", "std", "norm", "min", "max", "var", "abs_mean", "hist"]
 
 
 class _HookManager:
@@ -163,7 +165,7 @@ def __init__(
         run: Any,
         base_namespace: str,
         track_layers: Optional[List[Type[nn.Module]]] = None,
-        tensor_stats: Optional[List[Literal[tuple(TENSOR_STATS.keys())]]] = None,
+        tensor_stats: Optional[List[TensorStatType]] = None,
     ) -> None:
         """
         Initialize TorchWatcher with configuration options.
@@ -352,11 +354,3 @@ def watch(
 
         # Clear hooks and cached data
         self.hm.clear()
-
-    def __enter__(self):
-        """Context manager entry."""
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """Context manager exit - cleanup hooks."""
-        self.hm.remove_hooks()
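The `Literal` alias added above makes the allowed statistic names checkable both statically and at runtime. A sketch of a runtime check one could build on it, assuming only the stdlib (`validate_stats` is hypothetical, not part of the library):

```python
from typing import List, Literal, get_args

# Mirrors the alias added in _torchwatcher.py
TensorStatType = Literal["mean", "std", "norm", "min", "max", "var", "abs_mean", "hist"]

def validate_stats(stats: List[str]) -> List[str]:
    """Reject statistic names outside the Literal's allowed values."""
    allowed = set(get_args(TensorStatType))  # the Literal's members as strings
    unknown = [s for s in stats if s not in allowed]
    if unknown:
        raise ValueError(f"Unsupported statistics: {unknown}")
    return stats

print(validate_stats(["mean", "hist"]))  # ['mean', 'hist']
```

This is also why the PR replaces `Literal[tuple(TENSOR_STATS.keys())]` with a named alias: a `Literal` must be written with literal values, so building it dynamically from dict keys is not valid for static type checkers.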

tests/test_torchwatcher.py

Lines changed: 6 additions & 6 deletions

@@ -335,13 +335,13 @@ def test_histogram_processing(self, mock_run, test_model, test_data):
         assert hasattr(hist_data, "bin_edges"), "Histogram should have bin_edges"
         assert hasattr(hist_data, "counts"), "Histogram should have counts"
 
-    def test_context_manager(self, mock_run, test_model):
-        """Test TorchWatcher as context manager."""
-        with _TorchWatcher(model=test_model, run=mock_run, base_namespace="test") as tw:
-            assert tw is not None
-            assert len(tw.hm.hooks) > 0
+    def test_hook_cleanup_on_destruction(self, mock_run, test_model):
+        """Test that hooks are cleaned up when TorchWatcher is destroyed."""
+        tw = _TorchWatcher(model=test_model, run=mock_run, base_namespace="test")
+        assert len(tw.hm.hooks) > 0
 
-        # Hooks should be removed after context exit
+        # Manually call remove_hooks to test cleanup
+        tw.hm.remove_hooks()
         assert len(tw.hm.hooks) == 0
 
     def test_safe_tensor_stats(self, mock_run, test_model):
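The replacement test exercises the same guarantee without a context manager: every registered handle is detached and the hook list emptied. A minimal stdlib-only sketch of that pattern (both classes here are hypothetical stand-ins; the real manager is `_HookManager`, whose hooks are torch `RemovableHandle` objects):

```python
class FakeHandle:
    """Stand-in for torch.utils.hooks.RemovableHandle."""
    def __init__(self):
        self.removed = False

    def remove(self):
        self.removed = True

class HookManager:
    """Registers handles and detaches them all on cleanup."""
    def __init__(self, n_hooks: int):
        self.hooks = [FakeHandle() for _ in range(n_hooks)]

    def remove_hooks(self):
        for handle in self.hooks:
            handle.remove()  # detach from the module
        self.hooks = []      # drop references so nothing fires again

hm = HookManager(3)
assert len(hm.hooks) == 3
hm.remove_hooks()
assert hm.hooks == []
```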
