Skip to content

Commit f5b0c5c

Browse files
afzal442vfdev-5
andauthored
Added Arguments *args, **kwargs to BaseLogger.attach method (#2034)
* Added Arguments *args, **kwargs * Reformatted * Updated _test method and added default value of kwargs * Updated _test method * fix minor changes * fixed minor changes * Reformatted Co-authored-by: vfdev <[email protected]>
1 parent 9181716 commit f5b0c5c

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

ignite/contrib/handlers/base_logger.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,12 @@ class BaseLogger(metaclass=ABCMeta):
151151
"""
152152

153153
def attach(
154-
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList]
154+
self,
155+
engine: Engine,
156+
log_handler: Callable,
157+
event_name: Union[str, Events, CallableEventWithFilter, EventsList],
158+
*args: Any,
159+
**kwargs: Any,
155160
) -> RemovableEventHandle:
156161
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.
157162
@@ -161,6 +166,8 @@ def attach(
161166
event_name: event to attach the logging handler to. Valid events are from
162167
:class:`~ignite.engine.events.Events` or :class:`~ignite.engine.events.EventsList` or any `event_name`
163168
added by :meth:`~ignite.engine.engine.Engine.register_events`.
169+
args: args forwarded to the `log_handler` method
170+
kwargs: kwargs forwarded to the `log_handler` method
164171
165172
Returns:
166173
:class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
@@ -178,7 +185,7 @@ def attach(
178185
if event_name not in State.event_to_attr:
179186
raise RuntimeError(f"Unknown event name '{event_name}'")
180187

181-
return engine.add_event_handler(event_name, log_handler, self, event_name)
188+
return engine.add_event_handler(event_name, log_handler, self, event_name, *args, **kwargs)
182189

183190
def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
184191
"""Shortcut method to attach `OutputHandler` to the logger.

tests/ignite/contrib/handlers/test_base_logger.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_attach():
103103
n_epochs = 5
104104
data = list(range(50))
105105

106-
def _test(event, n_calls):
106+
def _test(event, n_calls, kwargs={}):
107107

108108
losses = torch.rand(n_epochs * len(data))
109109
losses_iter = iter(losses)
@@ -117,19 +117,24 @@ def update_fn(engine, batch):
117117

118118
mock_log_handler = MagicMock()
119119

120-
logger.attach(trainer, log_handler=mock_log_handler, event_name=event)
120+
logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)
121121

122122
trainer.run(data, max_epochs=n_epochs)
123123

124124
if isinstance(event, EventsList):
125125
events = [e for e in event]
126126
else:
127127
events = [event]
128-
calls = [call(trainer, logger, e) for e in events]
128+
129+
if len(kwargs) > 0:
130+
calls = [call(trainer, logger, e, **kwargs) for e in events]
131+
else:
132+
calls = [call(trainer, logger, e) for e in events]
133+
129134
mock_log_handler.assert_has_calls(calls)
130135
assert mock_log_handler.call_count == n_calls
131136

132-
_test(Events.ITERATION_STARTED, len(data) * n_epochs)
137+
_test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0})
133138
_test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
134139
_test(Events.EPOCH_STARTED, n_epochs)
135140
_test(Events.EPOCH_COMPLETED, n_epochs)

0 commit comments

Comments
 (0)