@@ -113,19 +113,19 @@ def __init__(
113113 self ._vis_hook_handler = None
114114 if log_model_diagram :
115115 self .run [self ._base_namespace ]["model" ]["summary" ] = str (model )
116- self .add_visualization_hook ()
116+ self ._add_visualization_hook ()
117117
118118 self .log_gradients = log_gradients
119119 self ._gradients_iter_tracker = {}
120120 self ._gradients_hook_handler = {}
121121 if self .log_gradients :
122- self .add_hooks_for_grads ()
122+ self ._add_hooks_for_grads ()
123123
124124 self .log_parameters = log_parameters
125125 self ._params_iter_tracker = 0
126126 self ._params_hook_handler = None
127127 if self .log_parameters :
128- self .add_hooks_for_params ()
128+ self ._add_hooks_for_params ()
129129
130130 # Log integration version
131131 root_obj = self .run
@@ -134,7 +134,7 @@ def __init__(
134134
135135 root_obj [INTEGRATION_VERSION_KEY ] = __version__
136136
137- def add_hooks_for_grads (self ):
137+ def _add_hooks_for_grads (self ):
138138 for name , parameter in self .model .named_parameters ():
139139 self ._gradients_iter_tracker [name ] = 0
140140
@@ -145,7 +145,7 @@ def hook(grad, name=name):
145145
146146 self ._gradients_hook_handler [name ] = parameter .register_hook (hook )
147147
148- def add_visualization_hook (self ):
148+ def _add_visualization_hook (self ):
149149 if not IS_TORCHVIZ_AVAILABLE :
150150 msg = "Skipping model visualization because no torchviz installation was found."
151151 warnings .warn (msg )
@@ -170,7 +170,7 @@ def hook(module, input, output):
170170
171171 self ._vis_hook_handler = self .model .register_forward_hook (hook )
172172
173- def add_hooks_for_params (self ):
173+ def _add_hooks_for_params (self ):
174174 def hook (module , inp , output ):
175175 self ._params_iter_tracker += 1
176176 if self ._params_iter_tracker % self .log_freq == 0 :
0 commit comments