Skip to content

Commit 5b9d92c

Browse files
authored
Prevent unexpected unpacking error when calling lr_finder.plot() with suggest_lr=True (#98)
* MAINT: always return 2 values when `suggest_lr` is True As mentioned in #88, the suggested lr would not be returned along with `ax` (`matplotlib.Axes`) if there are not sufficient data points to calculate the gradient of the lr-loss curve. Though it warns the user about this problem [1], users might still be confused by another error caused by unpacking the returned value. This is because users would usually expect it to work as below: ```python ax, lr = lr_finder.plot(..., suggest_lr=True) ``` But the second returned value `lr` might not exist when no suggested lr could be found; in that case a single value is returned instead. Therefore, the unpacking syntax `ax, lr = ...` would fail and result in the error reported in #88. So we fix it by always returning both `ax` and `suggested_lr` when the flag `suggest_lr` is True to meet this expectation, and leave the responsibility of checking whether `lr` is null to the user. [1]: https://github.com/davidtvs/pytorch-lr-finder/blob/fd9e949/torch_lr_finder/lr_finder.py#L539-L542 * MAINT: raise error earlier if there are not sufficient data points to suggest LR Now the LR finder will raise a RuntimeError if there are not sufficient data points to calculate the gradient for the suggested LR when `lr_finder.plot(..., suggest_lr=True)` is called. The error message clarifies the details of the failure, so users can fix the issue earlier as well.
1 parent fd9e949 commit 5b9d92c

File tree

2 files changed

+85
-34
lines changed

2 files changed

+85
-34
lines changed

tests/test_lr_finder.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import task as mod_task
99
import dataset as mod_dataset
1010

11+
import numpy as np
1112
import matplotlib.pyplot as plt
1213

1314
# Check available backends for mixed precision training
@@ -400,21 +401,68 @@ def test_plot_with_skip_and_suggest_lr(suggest_lr, skip_start, skip_end):
400401
)
401402

402403
fig, ax = plt.subplots()
403-
results = lr_finder.plot(
404-
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
405-
)
406404

407-
if num_iter - skip_start - skip_end <= 1:
408-
# handle data with one or zero lr
409-
assert len(ax.lines) == 1
410-
assert results is ax
405+
results = None
406+
if suggest_lr and num_iter < (skip_start + skip_end + 2):
407+
# No sufficient data points to calculate gradient, so this call should fail
408+
with pytest.raises(RuntimeError, match="Need at least"):
409+
results = lr_finder.plot(
410+
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
411+
)
412+
413+
# No need to proceed then
414+
return
411415
else:
412-
# handle different suggest_lr
413-
# for 'steepest': the point with steepest gradient (minimal gradient)
414-
assert len(ax.lines) == 1
415-
assert len(ax.collections) == int(suggest_lr)
416-
if results is not ax:
417-
assert len(results) == 2
416+
results = lr_finder.plot(
417+
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
418+
)
419+
420+
# NOTE:
421+
# - ax.lines[0]: the lr-loss curve. It should be always available once
422+
# `ax.plot(lrs, losses)` is called. But when there is no sufficient data
423+
# point (num_iter <= skip_start + skip_end), the coordinates will be
424+
# 2 empty arrays.
425+
# - ax.collections[0]: the point of suggested lr (type: <PathCollection>).
426+
# It's available only when there are sufficient data points to calculate
427+
# gradient of lr-loss curve.
428+
assert len(ax.lines) == 1
429+
430+
if suggest_lr:
431+
assert isinstance(results, tuple) and len(results) == 2
432+
433+
ret_ax, ret_lr = results
434+
assert ret_ax is ax
435+
436+
# XXX: Currently suggested lr is selected according to gradient of
437+
# lr-loss curve, so there should be at least 2 valid data points (after
438+
# filtered by `skip_start` and `skip_end`). If not, the returned lr
439+
# will be None.
440+
# But we would need to rework this if more suggestion
441+
# methods are supported in the future.
442+
if num_iter - skip_start - skip_end <= 1:
443+
assert ret_lr is None
444+
assert len(ax.collections) == 0
445+
else:
446+
assert len(ax.collections) == 1
447+
else:
448+
# Not suggesting lr, so it just plots a lr-loss curve.
449+
assert results is ax
450+
assert len(ax.collections) == 0
451+
452+
# Check whether the data of plotted line is the same as the one filtered
453+
# according to `skip_start` and `skip_end`.
454+
lrs = np.array(lr_finder.history["lr"])
455+
losses = np.array(lr_finder.history["loss"])
456+
x, y = ax.lines[0].get_data()
457+
458+
# If skip_end is 0, we should replace it with None. Otherwise, it
459+
# will create a slice as `x[0:-0]` which is an empty list.
460+
_slice = slice(skip_start, -skip_end if skip_end != 0 else None, None)
461+
assert np.allclose(x, lrs[_slice])
462+
assert np.allclose(y, losses[_slice])
463+
464+
# Close figure to release memory
465+
plt.close()
418466

419467

420468
def test_suggest_lr():

torch_lr_finder/lr_finder.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,13 @@ def plot(
510510
if show_lr is not None and not isinstance(show_lr, float):
511511
raise ValueError("show_lr must be float")
512512

513+
# Make sure there are enough data points to suggest a learning rate
514+
if suggest_lr and len(self.history["lr"]) < (skip_start + skip_end + 2):
515+
raise RuntimeError(
516+
f"Need at least {skip_start + skip_end + 2} iterations to suggest a "
517+
f"learning rate. Got {len(self.history['lr'])}"
518+
)
519+
513520
# Get the data to plot from the history dictionary. Also, handle skip_end=0
514521
# properly so the behaviour is the expected
515522
lrs = self.history["lr"]
@@ -533,25 +540,19 @@ def plot(
533540
if suggest_lr:
534541
# 'steepest': the point with steepest gradient (minimal gradient)
535542
print("LR suggestion: steepest gradient")
536-
min_grad_idx = None
537-
try:
538-
min_grad_idx = (np.gradient(np.array(losses))).argmin()
539-
except ValueError:
540-
print(
541-
"Failed to compute the gradients, there might not be enough points."
542-
)
543-
if min_grad_idx is not None:
544-
print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
545-
ax.scatter(
546-
lrs[min_grad_idx],
547-
losses[min_grad_idx],
548-
s=75,
549-
marker="o",
550-
color="red",
551-
zorder=3,
552-
label="steepest gradient",
553-
)
554-
ax.legend()
543+
min_grad_idx = (np.gradient(np.array(losses))).argmin()
544+
545+
print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
546+
ax.scatter(
547+
lrs[min_grad_idx],
548+
losses[min_grad_idx],
549+
s=75,
550+
marker="o",
551+
color="red",
552+
zorder=3,
553+
label="steepest gradient",
554+
)
555+
ax.legend()
555556

556557
if log_lr:
557558
ax.set_xscale("log")
@@ -565,8 +566,10 @@ def plot(
565566
if fig is not None:
566567
plt.show()
567568

568-
if suggest_lr and min_grad_idx is not None:
569-
return ax, lrs[min_grad_idx]
569+
if suggest_lr:
570+
# If suggest_lr is set, then we should always return 2 values.
571+
suggest_lr = lrs[min_grad_idx]
572+
return ax, suggest_lr
570573
else:
571574
return ax
572575

0 commit comments

Comments
 (0)