|
8 | 8 | import task as mod_task
|
9 | 9 | import dataset as mod_dataset
|
10 | 10 |
|
| 11 | +import numpy as np |
11 | 12 | import matplotlib.pyplot as plt
|
12 | 13 |
|
13 | 14 | # Check available backends for mixed precision training
|
@@ -400,21 +401,68 @@ def test_plot_with_skip_and_suggest_lr(suggest_lr, skip_start, skip_end):
|
400 | 401 | )
|
401 | 402 |
|
402 | 403 | fig, ax = plt.subplots()
|
403 |
| - results = lr_finder.plot( |
404 |
| - skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax |
405 |
| - ) |
406 | 404 |
|
407 |
| - if num_iter - skip_start - skip_end <= 1: |
408 |
| - # handle data with one or zero lr |
409 |
| - assert len(ax.lines) == 1 |
410 |
| - assert results is ax |
| 405 | + results = None |
| 406 | + if suggest_lr and num_iter < (skip_start + skip_end + 2): |
| 407 | + # Not enough data points to calculate the gradient, so this call should fail |
| 408 | + with pytest.raises(RuntimeError, match="Need at least"): |
| 409 | + results = lr_finder.plot( |
| 410 | + skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax |
| 411 | + ) |
| 412 | + |
| 413 | + # No need to proceed then |
| 414 | + return |
411 | 415 | else:
|
412 |
| - # handle different suggest_lr |
413 |
| - # for 'steepest': the point with steepest gradient (minimal gradient) |
414 |
| - assert len(ax.lines) == 1 |
415 |
| - assert len(ax.collections) == int(suggest_lr) |
416 |
| - if results is not ax: |
417 |
| - assert len(results) == 2 |
| 416 | + results = lr_finder.plot( |
| 417 | + skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax |
| 418 | + ) |
| 419 | + |
| 420 | + # NOTE: |
| 421 | + # - ax.lines[0]: the lr-loss curve. It should be always available once |
| 422 | + # `ax.plot(lrs, losses)` is called. But when there are not enough data |
| 423 | + # points (num_iter <= skip_start + skip_end), the coordinates will be |
| 424 | + # 2 empty arrays. |
| 425 | + # - ax.collections[0]: the point of suggested lr (type: <PathCollection>). |
| 426 | + # It's available only when there are sufficient data points to calculate |
| 427 | + # gradient of lr-loss curve. |
| 428 | + assert len(ax.lines) == 1 |
| 429 | + |
| 430 | + if suggest_lr: |
| 431 | + assert isinstance(results, tuple) and len(results) == 2 |
| 432 | + |
| 433 | + ret_ax, ret_lr = results |
| 434 | + assert ret_ax is ax |
| 435 | + |
| 436 | + # XXX: Currently suggested lr is selected according to gradient of |
| 437 | + # lr-loss curve, so there should be at least 2 valid data points (after |
| 438 | + # filtered by `skip_start` and `skip_end`). If not, the returned lr |
| 439 | + # will be None. |
| 440 | + # But we would need to rework this if more suggestion |
| 441 | + # methods are supported in the future. |
| 442 | + if num_iter - skip_start - skip_end <= 1: |
| 443 | + assert ret_lr is None |
| 444 | + assert len(ax.collections) == 0 |
| 445 | + else: |
| 446 | + assert len(ax.collections) == 1 |
| 447 | + else: |
| 448 | + # Not suggesting lr, so it just plots a lr-loss curve. |
| 449 | + assert results is ax |
| 450 | + assert len(ax.collections) == 0 |
| 451 | + |
| 452 | + # Check whether the data of plotted line is the same as the one filtered |
| 453 | + # according to `skip_start` and `skip_end`. |
| 454 | + lrs = np.array(lr_finder.history["lr"]) |
| 455 | + losses = np.array(lr_finder.history["loss"]) |
| 456 | + x, y = ax.lines[0].get_data() |
| 457 | + |
| 458 | + # If skip_end is 0, we should replace it with None. Otherwise, it |
| 459 | + # will create a slice as `x[0:-0]` which is an empty list. |
| 460 | + _slice = slice(skip_start, -skip_end if skip_end != 0 else None, None) |
| 461 | + assert np.allclose(x, lrs[_slice]) |
| 462 | + assert np.allclose(y, losses[_slice]) |
| 463 | + |
| 464 | + # Close figure to release memory |
| 465 | + plt.close() |
418 | 466 |
|
419 | 467 |
|
420 | 468 | def test_suggest_lr():
|
|
0 commit comments