Skip to content

Commit 9a46838

Browse files
kshitij12345SiddhantSadangi
authored andcommitted
chore: Update tests
2 parents c4cee0a + 6e00e23 commit 9a46838

File tree

5 files changed

+28
-28
lines changed

5 files changed

+28
-28
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
strategy:
2424
matrix:
2525
os: [ubuntu-latest, macos-latest, windows-latest]
26-
python-version: [3.9]
26+
python-version: [3.9, 3.13]
2727
steps:
2828
- uses: actions/checkout@v2
2929

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v4.3.0
3+
rev: v5.0.0
44
hooks:
55
- id: check-yaml
66
- id: end-of-file-fixer
77
- id: trailing-whitespace
88
- repo: https://github.com/pycqa/isort
9-
rev: 5.12.0
9+
rev: 6.0.1
1010
hooks:
1111
- id: isort
1212
args: [--settings-path, pyproject.toml]
1313
- repo: https://github.com/psf/black
14-
rev: 22.6.0
14+
rev: 25.1.0
1515
hooks:
1616
- id: black
1717
args: [--config, pyproject.toml]
1818
- repo: https://github.com/pycqa/flake8
19-
rev: 5.0.4
19+
rev: 6.0.0
2020
hooks:
2121
- id: flake8
2222
entry: pflake8

CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
## neptune-pytorch 1.1.0
1+
## neptune-pytorch 2.0.0
2+
3+
### Changes
4+
- Rename `save_model` to `log_model` and `save_checkpoint` to `log_checkpoint`. (https://github.com/neptune-ai/neptune-pytorch/pull/9)
5+
- Prefix private methods with underscore. (https://github.com/neptune-ai/neptune-pytorch/pull/12)
6+
- Add docstrings for `log_model` and `log_checkpoint`. (https://github.com/neptune-ai/neptune-pytorch/pull/11)
7+
8+
9+
## neptune-pytorch 1.1.0 (YANKED)
210

311
### Fixes
412
- Rename `save_model` to `log_model` and `save_checkpoint` to `log_checkpoint`. (https://github.com/neptune-ai/neptune-pytorch/pull/9)

pyproject.toml

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ style = "semver"
99
pattern = "default-unprefixed"
1010

1111
[tool.poetry.dependencies]
12-
python = "^3.7"
12+
python = "^3.9"
1313

1414
# Python lack of functionalities from future versions
1515
importlib-metadata = { version = "*", python = "<3.8" }
16+
setuptools = { version = "*", python = "<3.8" }
1617

1718
torch = ">1.8.0"
1819

@@ -25,28 +26,21 @@ neptune = { version = ">=1.0.0", optional = true }
2526
torchviz = { version = "*", optional = true }
2627

2728
[tool.poetry.extras]
28-
dev = [
29-
"pre-commit",
30-
"pytest",
31-
"pytest-cov",
32-
"pydot",
33-
"neptune",
34-
"torchviz",
35-
]
29+
dev = ["pre-commit", "pytest", "pytest-cov", "pydot", "neptune", "torchviz"]
3630

3731
[tool.poetry]
3832
authors = ["neptune.ai <[email protected]>"]
3933
description = "Neptune.ai pytorch integration library"
4034
repository = "https://github.com/neptune-ai/neptune-pytorch"
4135
homepage = "https://neptune.ai/"
42-
documentation = "https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch"
36+
documentation = "https://docs-beta.neptune.ai"
4337
include = ["CHANGELOG.md"]
4438
license = "Apache License 2.0"
4539
name = "neptune-pytorch"
4640
readme = "README.md"
4741
version = "0.0.0"
4842
classifiers = [
49-
"Development Status :: 4 - Beta",
43+
"Development Status :: 5 - Production/Stable",
5044
"Environment :: Console",
5145
"Intended Audience :: Developers",
5246
"Intended Audience :: Science/Research",
@@ -67,17 +61,15 @@ keywords = [
6761
"ML Model Store",
6862
"ML Metadata Store",
6963
]
70-
packages = [
71-
{ include = "neptune_pytorch", from = "src" },
72-
]
64+
packages = [{ include = "neptune_pytorch", from = "src" }]
7365

7466
[tool.poetry.urls]
7567
"Tracker" = "https://github.com/neptune-ai/neptune-pytorch/issues"
76-
"Documentation" = "https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch"
68+
"Documentation" = "https://docs-beta.neptune.ai"
7769

7870
[tool.black]
7971
line-length = 120
80-
target-version = ['py37', 'py38', 'py39', 'py310']
72+
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
8173
include = '\.pyi?$'
8274
exclude = '''
8375
/(

src/neptune_pytorch/impl/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,19 @@ def __init__(
113113
self._vis_hook_handler = None
114114
if log_model_diagram:
115115
self.run[self._base_namespace]["model"]["summary"] = str(model)
116-
self.add_visualization_hook()
116+
self._add_visualization_hook()
117117

118118
self.log_gradients = log_gradients
119119
self._gradients_iter_tracker = {}
120120
self._gradients_hook_handler = {}
121121
if self.log_gradients:
122-
self.add_hooks_for_grads()
122+
self._add_hooks_for_grads()
123123

124124
self.log_parameters = log_parameters
125125
self._params_iter_tracker = 0
126126
self._params_hook_handler = None
127127
if self.log_parameters:
128-
self.add_hooks_for_params()
128+
self._add_hooks_for_params()
129129

130130
# Log integration version
131131
root_obj = self.run
@@ -134,7 +134,7 @@ def __init__(
134134

135135
root_obj[INTEGRATION_VERSION_KEY] = __version__
136136

137-
def add_hooks_for_grads(self):
137+
def _add_hooks_for_grads(self):
138138
for name, parameter in self.model.named_parameters():
139139
self._gradients_iter_tracker[name] = 0
140140

@@ -145,7 +145,7 @@ def hook(grad, name=name):
145145

146146
self._gradients_hook_handler[name] = parameter.register_hook(hook)
147147

148-
def add_visualization_hook(self):
148+
def _add_visualization_hook(self):
149149
if not IS_TORCHVIZ_AVAILABLE:
150150
msg = "Skipping model visualization because no torchviz installation was found."
151151
warnings.warn(msg)
@@ -170,7 +170,7 @@ def hook(module, input, output):
170170

171171
self._vis_hook_handler = self.model.register_forward_hook(hook)
172172

173-
def add_hooks_for_params(self):
173+
def _add_hooks_for_params(self):
174174
def hook(module, inp, output):
175175
self._params_iter_tracker += 1
176176
if self._params_iter_tracker % self.log_freq == 0:

0 commit comments

Comments
 (0)