Skip to content

Commit 884107c

Browse files
authored
Release v0.5 (#476)
1 parent 12ea371 commit 884107c

File tree

7 files changed

+48
-81
lines changed

7 files changed

+48
-81
lines changed

CHANGELOG.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
11
# Changelog
22

3+
## v0.5.0 (2022-02-16)
4+
5+
### Enhancements
6+
7+
* Bump Nx dependency
8+
* Update documentation to account for channels last default
9+
* Improve error message in compilation/build errors for models
10+
* Remove deprecated `transform`
11+
12+
### Deprecations
13+
14+
* Deprecate `Axon.Loop.handle/4`
15+
16+
## v0.4.1 (2022-01-21)
17+
18+
### Bug Fixes
19+
20+
* Fixed a shape mismatch when training with certain optimizers
21+
322
## v0.4.0 (2022-01-19)
423

524
### Enhancements

lib/axon/defn.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,7 @@ defmodule Axon.Defn do
1919

2020
@impl true
2121
def __compile__(_, _, _, _), do: raise("not implemented")
22+
23+
@impl true
24+
def __partitions_options__(_), do: raise("not implemented")
2225
end

lib/axon/loop.ex

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@ defmodule Axon.Loop do
152152
:iteration_completed, # On iteration complete
153153
:epoch_completed, # On epoch complete
154154
:epoch_halted, # On epoch halt, if early halted
155-
:halted, # On loop halt, if early halted
156-
:completed # On loop completion
157155
]
158156
159157
You can attach event handlers to events using `Axon.Loop.handle_event/4`:
@@ -229,9 +227,7 @@ defmodule Axon.Loop do
229227
:iteration_started,
230228
:iteration_completed,
231229
:epoch_completed,
232-
:epoch_halted,
233-
:halted,
234-
:completed
230+
:epoch_halted
235231
]
236232

237233
@default_handlers %{
@@ -896,8 +892,6 @@ defmodule Axon.Loop do
896892
:iteration_completed, # On iteration complete
897893
:epoch_completed, # On epoch complete
898894
:epoch_halted, # On epoch halt, if early halted
899-
:halted, # On loop halt, if early halted
900-
:completed # On loop completion
901895
]
902896
903897
Generally, event handlers are side-effecting operations which provide some
@@ -1066,7 +1060,6 @@ defmodule Axon.Loop do
10661060

10671061
metrics =
10681062
Enum.reduce(metric_fns, evaluator, fn {k, {_, v}}, loop -> metric(loop, v, k) end)
1069-
|> log(fn _ -> "\n" end, event: :completed)
10701063
|> run(validation_data, model_state)
10711064
|> Access.get(0)
10721065
|> Map.new(fn {k, v} ->
@@ -1733,8 +1726,7 @@ defmodule Axon.Loop do
17331726
end
17341727
end
17351728

1736-
{_, state} = fire_event(status, handler_fns, state, debug?)
1737-
state = %State{state | metrics: final_metrics}
1729+
state = %State{state | metrics: final_metrics, status: status}
17381730

17391731
output_transform.(state)
17401732
end

lib/axon/loop/state.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,14 @@ defmodule Axon.Loop.State do
4242
4343
`event_counts` is a metadata field which stores information about the number
4444
of times each event has been fired. This is useful when creating custom filters.
45+
46+
`status` refers to the loop state status after the loop has executed. You can
47+
use this to determine if the loop ran to completion or if it was halted early.
4548
"""
4649
@enforce_keys [:step_state]
4750
defstruct [
4851
:step_state,
52+
:status,
4953
handler_metadata: %{},
5054
epoch: 0,
5155
max_epoch: 1,

mix.exs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defmodule Axon.MixProject do
22
use Mix.Project
33

44
@source_url "https://github.com/elixir-nx/axon"
5-
@version "0.4.1"
5+
@version "0.5.0"
66

77
def project do
88
[
@@ -35,9 +35,9 @@ defmodule Axon.MixProject do
3535
# Run "mix help deps" to learn about dependencies.
3636
defp deps do
3737
[
38-
{:exla, "~> 0.4.0", [only: :test] ++ exla_opts()},
39-
{:torchx, "~> 0.4.0", [only: :test] ++ torchx_opts()},
40-
{:nx, "~> 0.4.0", nx_opts()},
38+
{:exla, "~> 0.5.0", [only: :test] ++ exla_opts()},
39+
{:torchx, "~> 0.5.0", [only: :test] ++ torchx_opts()},
40+
{:nx, "~> 0.5.0", nx_opts()},
4141
{:ex_doc, "~> 0.23", only: :docs},
4242
{:table_rex, "~> 3.1.1", optional: true},
4343
{:kino, "~> 0.7", optional: true},
@@ -57,23 +57,23 @@ defmodule Axon.MixProject do
5757
if path = System.get_env("AXON_NX_PATH") do
5858
[path: path, override: true]
5959
else
60-
[github: "elixir-nx/nx", sparse: "nx", override: true]
60+
[]
6161
end
6262
end
6363

6464
defp exla_opts do
6565
if path = System.get_env("AXON_EXLA_PATH") do
6666
[path: path]
6767
else
68-
[github: "elixir-nx/nx", sparse: "exla", override: true]
68+
[]
6969
end
7070
end
7171

7272
defp torchx_opts do
7373
if path = System.get_env("AXON_TORCHX_PATH") do
7474
[path: path]
7575
else
76-
[github: "elixir-nx/nx", sparse: "torchx", override: true]
76+
[]
7777
end
7878
end
7979

mix.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
%{
22
"castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"},
33
"cc_precompiler": {:hex, :cc_precompiler, "0.1.5", "ac3ef86f31ab579b856192a948e956cc3e4bb5006e303c4ab4b24958108e218a", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "ee5b2e56eb03798231a3d322579fff509139a534ef54205d04c188e18cab1f57"},
4-
"complex": {:hex, :complex, "0.4.3", "84db4aad241099a8785446ac6eacf498bf3a60634a0e45c7745d875714ddbf98", [:mix], [], "hexpm", "2ceda96ebddcc22697974f1a2666d4cc5dfdd34f8cd8c4f9dced037bcb41eeb5"},
4+
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
55
"dll_loader_helper": {:hex, :dll_loader_helper, "0.1.10", "ba85d66f82c1748513dbaee71aa9d0593bb9a65dba246b980753c4d683b0a07b", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "c0d02a2d8cd0085252f7551a343f89060bb7beb3f303d991e46a7370ed257485"},
66
"earmark_parser": {:hex, :earmark_parser, "1.4.30", "0b938aa5b9bafd455056440cdaa2a79197ca5e693830b4a982beada840513c5f", [:mix], [], "hexpm", "3b5385c2d36b0473d0b206927b841343d25adb14f95f0110062506b300cd5a1b"},
77
"elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"},
88
"ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"},
9-
"exla": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "exla"]},
9+
"exla": {:hex, :exla, "0.5.0", "a002cb70e59c26d4ec78a256489e4026c428ff4917f25d266e6a86c58636dc7f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "9219366cb0ea18c421349b8e0f130d85e83d7404df8054e5af6e18a47540c886"},
1010
"kino": {:hex, :kino, "0.8.1", "da3b2cba121b7542146cffdb8af055fa0129395fa67aead9e7e3df93aed1f107", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "da45dd141db30db18973de0e3398bda3ab8cb0b5da58d6a0debbe5b864aba295"},
1111
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"},
1212
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
1313
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
1414
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
1515
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
16-
"nx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "nx"]},
16+
"nx": {:hex, :nx, "0.5.0", "c5e62e82606ff372d986e72cce505c98421bb4305ce9cc8e439fe6cc1966c6ad", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b29c246318181c3ebfcf0f230a0d33783ac4c92dfa34ca3aa5b9b38ae58c187e"},
1717
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
1818
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
1919
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
20-
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "torchx"]},
20+
"torchx": {:hex, :torchx, "0.5.0", "d787ea5a62f299a93c03a7a9f1d0d903dd854797e8fc27bbbee984d8e3e6acf1", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "832205d22259011930231e5203cc1b929136a3ad1b160e1f4690d35dfb11ddbd"},
2121
"vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"},
2222
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
2323
}

test/axon/loop_test.exs

Lines changed: 9 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ defmodule Axon.LoopTest do
360360
Axon.input("input", shape: {nil, 1})
361361
|> Axon.dense(1)
362362
|> Loop.trainer(:binary_cross_entropy, :sgd, log: 0)
363-
|> Loop.handle(
363+
|> Loop.handle_event(
364364
:epoch_completed,
365365
fn %State{step_state: pstate} = state ->
366366
{
@@ -376,14 +376,6 @@ defmodule Axon.LoopTest do
376376
}
377377
end
378378
)
379-
|> Loop.handle(
380-
:completed,
381-
fn %State{step_state: %{counter: counter}} = state ->
382-
assert 4 = counter
383-
384-
{:continue, state}
385-
end
386-
)
387379
|> Loop.run(
388380
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
389381
%{},
@@ -396,7 +388,7 @@ defmodule Axon.LoopTest do
396388
Axon.input("input", shape: {nil, 1})
397389
|> Axon.dense(1)
398390
|> Loop.trainer(:binary_cross_entropy, :sgd, log: 0)
399-
|> Loop.handle(
391+
|> Loop.handle_event(
400392
:epoch_completed,
401393
fn %State{step_state: pstate} = state ->
402394
{
@@ -416,14 +408,6 @@ defmodule Axon.LoopTest do
416408
}
417409
end
418410
)
419-
|> Loop.handle(
420-
:completed,
421-
fn %State{step_state: %{counter: counter}} = state ->
422-
assert {{4}, 4} = counter
423-
424-
{:continue, state}
425-
end
426-
)
427411
|> Loop.run(
428412
[{Nx.tensor([[1.0]]), Nx.tensor([[1.0]])}],
429413
%{},
@@ -477,7 +461,7 @@ defmodule Axon.LoopTest do
477461
end
478462

479463
def send_handler(loop, event) do
480-
Axon.Loop.handle(loop, event, fn state ->
464+
Axon.Loop.handle_event(loop, event, fn state ->
481465
send(self(), event)
482466
{:continue, state}
483467
end)
@@ -540,15 +524,6 @@ defmodule Axon.LoopTest do
540524
refute_received :iteration_completed
541525
end
542526

543-
test "fires correctly on :completed" do
544-
ExUnit.CaptureIO.capture_io(fn ->
545-
run_dummy_loop!(:completed, 5, 10)
546-
end)
547-
548-
assert_received :completed
549-
refute_received :completed
550-
end
551-
552527
test "fires correctly on :epoch_halted" do
553528
model = Axon.input("foo")
554529

@@ -562,7 +537,7 @@ defmodule Axon.LoopTest do
562537
ExUnit.CaptureIO.capture_io(fn ->
563538
model
564539
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
565-
|> Axon.Loop.handle(:iteration_completed, fn state ->
540+
|> Axon.Loop.handle_event(:iteration_completed, fn state ->
566541
{:halt_epoch, state}
567542
end)
568543
|> send_handler(:epoch_halted)
@@ -576,30 +551,6 @@ defmodule Axon.LoopTest do
576551
refute_received :epoch_halted
577552
end
578553

579-
test "fires correctly on :halted" do
580-
model = Axon.input("foo")
581-
582-
data =
583-
Stream.repeatedly(fn ->
584-
xs = Nx.tensor([[Enum.random(0..10)]])
585-
ys = Nx.greater(xs, 5)
586-
{xs, ys}
587-
end)
588-
589-
ExUnit.CaptureIO.capture_io(fn ->
590-
model
591-
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
592-
|> Axon.Loop.handle(:iteration_completed, fn state ->
593-
{:halt_loop, state}
594-
end)
595-
|> send_handler(:halted)
596-
|> Axon.Loop.run(data, %{}, epochs: 5, iterations: 10)
597-
end)
598-
599-
assert_received :halted
600-
refute_received :halted
601-
end
602-
603554
test "events fire in order" do
604555
model = Axon.input("foo")
605556

@@ -618,7 +569,6 @@ defmodule Axon.LoopTest do
618569
|> send_handler(:iteration_started)
619570
|> send_handler(:iteration_completed)
620571
|> send_handler(:epoch_completed)
621-
|> send_handler(:completed)
622572
|> Axon.Loop.run(data, %{}, epochs: 1, iterations: 1)
623573
end)
624574

@@ -627,7 +577,6 @@ defmodule Axon.LoopTest do
627577
assert_received :iteration_started
628578
assert_received :iteration_completed
629579
assert_received :epoch_completed
630-
assert_received :completed
631580

632581
refute_received _
633582
end
@@ -651,7 +600,7 @@ defmodule Axon.LoopTest do
651600
end
652601

653602
def send_handler(loop, event, filter) do
654-
Axon.Loop.handle(
603+
Axon.Loop.handle_event(
655604
loop,
656605
event,
657606
fn state ->
@@ -863,7 +812,7 @@ defmodule Axon.LoopTest do
863812
model
864813
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
865814
|> Axon.Loop.from_state(state1)
866-
|> Axon.Loop.handle(:epoch_completed, fn %{epoch: epoch} = state ->
815+
|> Axon.Loop.handle_event(:epoch_completed, fn %{epoch: epoch} = state ->
867816
assert epoch >= 3
868817
{:continue, state}
869818
end)
@@ -888,7 +837,7 @@ defmodule Axon.LoopTest do
888837
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
889838
|> Axon.Loop.metric(:accuracy)
890839
|> Axon.Loop.validate(model, Enum.take(data, 5))
891-
|> Axon.Loop.handle(
840+
|> Axon.Loop.handle_event(
892841
:epoch_completed,
893842
fn %{metrics: metrics} = state ->
894843
assert Map.has_key?(metrics, "validation_accuracy")
@@ -918,7 +867,7 @@ defmodule Axon.LoopTest do
918867
|> Axon.Loop.metric(:accuracy)
919868
|> Axon.Loop.validate(model, Enum.take(data, 5))
920869
|> Axon.Loop.early_stop("validation_accuracy", mode: :max)
921-
|> Axon.Loop.handle(
870+
|> Axon.Loop.handle_event(
922871
:epoch_completed,
923872
fn %{handler_metadata: meta} = state ->
924873
assert %{early_stop: %{"validation_accuracy" => _, :since_last_improvement => _}} =
@@ -1006,7 +955,7 @@ defmodule Axon.LoopTest do
1006955
|> Axon.Loop.metric(:accuracy)
1007956
|> Axon.Loop.validate(model, Enum.take(data, 5))
1008957
|> Axon.Loop.reduce_lr_on_plateau("validation_accuracy", mode: :max)
1009-
|> Axon.Loop.handle(
958+
|> Axon.Loop.handle_event(
1010959
:epoch_completed,
1011960
fn %{handler_metadata: meta} = state ->
1012961
assert %{reduce_lr: %{"validation_accuracy" => _, :since_last_improvement => _}} =

0 commit comments

Comments
 (0)