From 0e2cb3b356b4e6091f898dbd2861dedb781a9d1b Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 11 May 2024 14:33:59 -0400 Subject: [PATCH] Loops as streams --- lib/axon/loop.ex | 501 +++++++++++++++++++--------------------- lib/axon/loop/state.ex | 10 - test/axon/loop_test.exs | 38 +-- 3 files changed, 249 insertions(+), 300 deletions(-) diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 232ab768..30ee5903 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -49,9 +49,7 @@ defmodule Axon.Loop do %Axon.Loop.State{ epoch: integer(), - max_epoch: integer(), iteration: integer(), - max_iteration: integer(), metrics: map(string(), container()), times: map(integer(), integer()), step_state: container() @@ -66,14 +64,6 @@ defmodule Axon.Loop do state. For machine learning tasks, the initialization function will return things like initial model parameters and optimizer state. - Typically, the final output of the loop is the accumulated final state; however, you - may optionally apply an output transform to extract specific values at the end of the - loop. For example, `Axon.Loop.trainer/4` by default extracts trained model state: - - output_transform = fn state -> - state.step_state[:model_state] - end - ## Initialize and Step The core of the Axon loop are the init and step functions. The initialization is an @@ -274,7 +264,6 @@ defmodule Axon.Loop do :init, :step, :attached_state, - :output_transform, metrics: %{}, handlers: @default_handlers ] @@ -531,8 +520,7 @@ defmodule Axon.Loop do ## Loop Factories @doc """ - Creates a loop from `step_fn`, an optional `init_fn`, and an - optional `output_transform`. + Creates a loop from `step_fn` and an optional `init_fn`. `step_fn` is an arity-2 function which takes a batch and state and returns an updated step state: @@ -557,18 +545,12 @@ defmodule Axon.Loop do within `Nx.Defn.jit/3`. While JIT-compilation will work with anonymous functions, `def`, and `defn`, it is recommended that you use the stricter `defn` to define both functions in order to avoid bugs or cryptic errors. - - `output_transform/1` applies a transformation on the final accumulated loop state. - This is useful for extracting specific fields from a loop and piping them into - additional functions. """ - def loop(step_fn, init_fn \\ &default_init/2, output_transform \\ & &1) - when is_function(step_fn, 2) and is_function(init_fn, 2) and - is_function(output_transform, 1) do + def loop(step_fn, init_fn \\ &default_init/2) + when is_function(step_fn, 2) and is_function(init_fn, 2) do %Loop{ init: init_fn, - step: step_fn, - output_transform: output_transform + step: step_fn } end @@ -679,11 +661,10 @@ defmodule Axon.Loop do {init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts) log_interval = opts[:log] || 50 - output_transform = fn state -> state.step_state[:model_state] end loop = step_fn - |> loop(init_fn, output_transform) + |> loop(init_fn) |> metric(loss_fn, "loss") if log_interval > 0 do @@ -773,9 +754,9 @@ defmodule Axon.Loop do """ def evaluator(model) do {init_fn, step_fn} = eval_step(model) - output_transform = fn state -> state.metrics end - loop(step_fn, init_fn, output_transform) + step_fn + |> loop(init_fn) |> log(&supervised_log_message_fn(&1, false), event: :iteration_completed) end @@ -1529,7 +1510,7 @@ defmodule Axon.Loop do It is the opposite of `Axon.Loop.serialize_state/2`. - By default, the step state is deserialized using `Nx.deserialize.2`; + By default, the step state is deserialized using `Nx.deserialize/2`; however, this behavior can be changed if step state is an application specific container. For example, if you introduce your own data structure into step_state and you customized the serialization logic, @@ -1547,6 +1528,118 @@ defmodule Axon.Loop do struct!(Axon.Loop.State, state_map) end + @doc """ + Creates an Elixir `Stream` from the given loop, data, and state. + + The stream will lazily return the loop state after each training epoch. + + ## Options + + """ + def stream( + %Loop{ + init: init_fn, + step: step_fn, + handlers: handler_fns, + metrics: metric_fns, + attached_state: attached_state + }, + data, + init_state \\ %{}, + opts \\ [] + ) do + {max_iterations, opts} = Keyword.pop(opts, :iterations, -1) + {jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true) + {strict?, opts} = Keyword.pop(opts, :strict?, true) + {force_garbage_collection?, jit_opts} = Keyword.pop(opts, :force_garbage_collection?, false) + debug? = Keyword.get(jit_opts, :debug, false) + + sample_data = + case Enum.take(data, 1) do + [sample_data | _] -> + sample_data + + [] -> + raise ArgumentError, + "Axon.Loop.stream/4 received empty dataset, this can happen" <> + " if you've built a stream and accidentally filtered" <> + " out every value, your dataset must have at least one" <> + " entry" + end + + loop_state = + init_loop_state( + init_fn, + sample_data, + init_state, + attached_state, + debug?, + jit_compile?, + jit_opts + ) + + # TODO: Can we infer here? + zero_metrics = Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0, type: :f32)} end) + loop_state = %{loop_state | metrics: zero_metrics} + + case fire_event(:started, handler_fns, loop_state, debug?) do + {halt, loop_state} when halt in [:halt_epoch, :halt_loop] -> + Stream.unfold(loop_state, fn _ -> nil end) + + {:continue, loop_state} -> + batch_fn = + {:non_compiled, build_batch_fn(step_fn, metric_fns), jit_compile?, strict?, jit_opts} + + Stream.unfold({loop_state, batch_fn}, fn + :halt -> + nil + + {loop_state, batch_fn} -> + case fire_event(:epoch_started, handler_fns, loop_state, debug?) do + {:halt_epoch, state} -> + halt_epoch(handler_fns, batch_fn, state, debug?) + + {:halt_loop, state} -> + {state, :halt} + + {:continue, state} -> + {time, out} = + run_epoch( + batch_fn, + handler_fns, + state, + data, + max_iterations, + debug?, + force_garbage_collection? + ) + + case out do + {:halt_epoch, batch_fn, state} -> + halt_epoch(handler_fns, batch_fn, state, debug?) + + {:halt_loop, _, state} -> + {state, :halt} + + {:continue, batch_fn, state} -> + new_loop_state = put_in(state.times[0], time) + + case fire_event(:epoch_completed, handler_fns, new_loop_state, debug?) do + {:halt_epoch, state} -> + halt_epoch(handler_fns, batch_fn, state, debug?) + + {:halt_loop, state} -> + {state, :halt} + + {:continue, state} -> + {state, {%{state | epoch: state.epoch + 1, iteration: 0}, batch_fn}} + end + end + end + end) + end + end + @doc """ Runs the given loop on data with the given options. @@ -1588,262 +1681,153 @@ defmodule Axon.Loop do """ def run(loop, data, init_state \\ %{}, opts \\ []) do {max_epochs, opts} = Keyword.pop(opts, :epochs, 1) - {max_iterations, opts} = Keyword.pop(opts, :iterations, -1) - {jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true) - {strict?, opts} = Keyword.pop(opts, :strict?, true) - {force_garbage_collection?, jit_opts} = Keyword.pop(opts, :force_garbage_collection?, false) - debug? = Keyword.get(jit_opts, :debug, false) - if jit_opts != [] do - Logger.debug("Forwarding options: #{inspect(jit_opts)} to JIT compiler") - end - - %Loop{ - init: init_fn, - step: step_fn, - handlers: handler_fns, - metrics: metric_fns, - attached_state: attached_state, - output_transform: output_transform - } = loop - - sample_data = - case Enum.take(data, 1) do - [sample_data | _] -> - sample_data + loop + |> stream(data, init_state, opts) + |> Stream.take(max_epochs) + |> Stream.with_index() + |> Enum.reduce(%{metrics: %{}, times: %{}}, fn {epoch_state, i}, acc_state -> + %{ + epoch_state + | metrics: Map.put(acc_state.metrics, i + 1, epoch_state.metrics), + times: Map.put(acc_state.times, i + 1, epoch_state.times[0]) + } + end) + end - [] -> - raise ArgumentError, - "Axon.Loop.run received empty dataset, this can happen" <> - " if you've built a stream and accidentally filtered" <> - " out every value, your dataset must have at least one" <> - " entry" - end + ## Helpers + defp init_loop_state( + init_fn, + sample_data, + init_state, + attached_state, + debug?, + jit_compile?, + jit_opts + ) do if debug? do Logger.debug("Axon.Loop started initializing loop state") end - {time, loop_state} = + {time, state} = :timer.tc(fn -> - init_loop_state( - init_fn, - sample_data, - init_state, - attached_state, - max_epochs, - max_iterations, - jit_compile?, - jit_opts - ) + case attached_state do + %State{} = state -> + state + + nil -> + step_state = maybe_jit(init_fn, [sample_data, init_state], jit_compile?, jit_opts) + + %State{ + epoch: 0, + iteration: 0, + step_state: step_state, + metrics: %{}, + times: %{} + } + end end) - epoch_start = loop_state.epoch - epoch_end = max_epochs + epoch_start - 1 - if debug? do Logger.debug("Axon.Loop finished initializing loop state in #{us_to_ms(time)}ms") end - # TODO: Can we infer here? - zero_metrics = Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0, type: :f32)} end) - final_metrics_map = loop_state.metrics - loop_state = %{loop_state | metrics: zero_metrics} - - {status, final_metrics_map, state} = - case fire_event(:started, handler_fns, loop_state, debug?) do - {:halt_epoch, state} -> - {:halted, final_metrics_map, state} - - {:halt_loop, state} -> - {:halted, final_metrics_map, state} - - {:continue, state} -> - batch_fn = - {:non_compiled, build_batch_fn(step_fn, metric_fns), jit_compile?, strict?, jit_opts} - - Enum.reduce_while( - epoch_start..epoch_end//1, - {batch_fn, final_metrics_map, state}, - fn epoch, {batch_fn, final_metrics_map, loop_state} -> - case fire_event(:epoch_started, handler_fns, loop_state, debug?) do - {:halt_epoch, state} -> - halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?) - - {:halt_loop, state} -> - {:halt, {final_metrics_map, state}} - - {:continue, state} -> - if debug? do - Logger.debug("Axon.Loop started running epoch #{epoch}") - end - - {time, status_batch_fn_and_state} = - :timer.tc(&run_epoch/6, [ - batch_fn, - handler_fns, - state, - data, - debug?, - force_garbage_collection? - ]) - - if debug? do - Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms") - end - - case status_batch_fn_and_state do - {:halt_epoch, batch_fn, state} -> - halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?) - - {:halt_loop, _, state} -> - {:halt, {final_metrics_map, state}} - - {:continue, batch_fn, state} -> - new_loop_state = put_in(state.times[epoch], time) - - case fire_event(:epoch_completed, handler_fns, new_loop_state, debug?) do - {:halt_epoch, state} -> - halt_epoch(handler_fns, batch_fn, final_metrics_map, state, debug?) - - {:halt_loop, state} -> - {:halt, {final_metrics_map, state}} - - {:continue, state} -> - {:cont, - {batch_fn, Map.put(final_metrics_map, epoch, state.metrics), - %State{ - state - | epoch: epoch + 1, - metrics: zero_metrics, - iteration: 0, - max_iteration: state.max_iteration - }}} - end - end - end - end - ) - |> case do - {final_metrics_map, state} -> {:halted, final_metrics_map, state} - {_batch_fn, final_metrics_map, state} -> {:completed, final_metrics_map, state} - end - end - - # Fill in epochs in case it was halted. It is a no-op otherwise. - final_metrics_map = - Enum.reduce( - state.epoch..epoch_end//1, - final_metrics_map, - &Map.put(&2, &1, zero_metrics) - ) - - state = %State{state | metrics: final_metrics_map, status: status} - output_transform.(state) + state end - ## Helpers - - defp init_loop_state( - init_fn, - sample_data, - init_state, - attached_state, - max_epochs, - max_iterations, - jit_compile?, - jit_opts + defp run_epoch( + batch_fn, + handler_fns, + loop_state, + data, + max_iters, + debug?, + force_garbage_collection? ) do - case attached_state do - %State{} = state -> - %{state | max_epoch: max_epochs + state.epoch} - - nil -> - step_state = maybe_jit(init_fn, [sample_data, init_state], jit_compile?, jit_opts) - - %State{ - epoch: 0, - max_epoch: max_epochs, - iteration: 0, - max_iteration: max_iterations, - step_state: step_state, - metrics: %{}, - times: %{} - } + if debug? do + Logger.debug("Axon.Loop started running epoch") end - end - defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, force_garbage_collection?) do - Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} -> - case fire_event(:iteration_started, handler_fns, state, debug?) do - {:halt_epoch, state} -> - {:halt, {:halt_epoch, batch_fn, state}} - - {:halt_loop, state} -> - {:halt, {:halt_loop, batch_fn, state}} - - {:continue, state} -> - %State{ - iteration: iters, - max_iteration: max_iters, - step_state: step_state, - metrics: metrics - } = state - - batch_fn = - case batch_fn do - {:non_compiled, batch_fn, jit_compile?, strict?, jit_opts} -> - cond do - jit_compile? and strict? -> - Nx.Defn.compile(batch_fn, [data, iters, step_state, metrics], jit_opts) - - jit_compile? -> - Nx.Defn.jit(batch_fn, jit_opts) - - true -> + {time, out} = + :timer.tc(fn -> + Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, + {_, batch_fn, state} -> + case fire_event(:iteration_started, handler_fns, state, debug?) do + {:halt_epoch, state} -> + {:halt, {:halt_epoch, batch_fn, state}} + + {:halt_loop, state} -> + {:halt, {:halt_loop, batch_fn, state}} + + {:continue, state} -> + %State{ + iteration: iters, + step_state: step_state, + metrics: metrics + } = state + + batch_fn = + case batch_fn do + {:non_compiled, batch_fn, jit_compile?, strict?, jit_opts} -> + cond do + jit_compile? and strict? -> + Nx.Defn.compile(batch_fn, [data, iters, step_state, metrics], jit_opts) + + jit_compile? -> + Nx.Defn.jit(batch_fn, jit_opts) + + true -> + batch_fn + end + + {:compiled, batch_fn} -> batch_fn end - {:compiled, batch_fn} -> - batch_fn - end - - if debug? do - Logger.debug("Axon.Loop started batch step execution") - end + if debug? do + Logger.debug("Axon.Loop started batch step execution") + end - {time, {new_step_state, new_metrics}} = - :timer.tc(fn -> batch_fn.(data, iters, step_state, metrics) end) + {time, {new_step_state, new_metrics}} = + :timer.tc(fn -> batch_fn.(data, iters, step_state, metrics) end) - if debug? do - Logger.debug("Axon.Loop finished batch step execution in #{us_to_ms(time)}ms") - end + if debug? do + Logger.debug("Axon.Loop finished batch step execution in #{us_to_ms(time)}ms") + end - batch_fn = {:compiled, batch_fn} - state = %{state | step_state: new_step_state, metrics: new_metrics} + batch_fn = {:compiled, batch_fn} + state = %{state | step_state: new_step_state, metrics: new_metrics} - case fire_event(:iteration_completed, handler_fns, state, debug?) do - {:halt_epoch, state} -> - {:halt, {:halt_epoch, batch_fn, state}} + case fire_event(:iteration_completed, handler_fns, state, debug?) do + {:halt_epoch, state} -> + {:halt, {:halt_epoch, batch_fn, state}} - {:halt_loop, state} -> - {:halt, {:halt_loop, batch_fn, state}} + {:halt_loop, state} -> + {:halt, {:halt_loop, batch_fn, state}} - {:continue, state} -> - if force_garbage_collection? do - :erlang.garbage_collect() - end + {:continue, state} -> + if force_garbage_collection? do + :erlang.garbage_collect() + end - state = %{state | iteration: iters + 1} + state = %{state | iteration: iters + 1} - if max_iterations_reached?(max_iters, iters) do - {:halt, {:continue, batch_fn, state}} - else - {:cont, {:continue, batch_fn, state}} + if max_iterations_reached?(max_iters, iters) do + {:halt, {:continue, batch_fn, state}} + else + {:cont, {:continue, batch_fn, state}} + end end end - end - end) + end) + end) + + if debug? do + Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms") + end + + {time, out} end defp max_iterations_reached?(max_iters, iters) do @@ -1917,17 +1901,16 @@ defmodule Axon.Loop do end # Halts an epoch during looping - defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do + defp halt_epoch(handler_fns, batch_fn, loop_state, debug?) do case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do - {:halt_epoch, %{epoch: epoch, metrics: metrics} = state} -> - final_metrics_map = Map.put(final_metrics_map, epoch, metrics) - {:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}} + {:continue, state} -> + {:continue, batch_fn, state} - {:halt_loop, state} -> - {:halt, {final_metrics_map, state}} + {:halt_epoch, %{epoch: epoch} = state} -> + {:continue, batch_fn, %State{state | epoch: epoch + 1, iteration: 0}} - {:continue, state} -> - {:cont, {batch_fn, final_metrics_map, state}} + {:halt_loop, _state} -> + :halt end end diff --git a/lib/axon/loop/state.ex b/lib/axon/loop/state.ex index eaccef22..b885bf6c 100644 --- a/lib/axon/loop/state.ex +++ b/lib/axon/loop/state.ex @@ -6,9 +6,7 @@ defmodule Axon.Loop.State do %State{ epoch: integer(), - max_epoch: integer(), iteration: integer(), - max_iteration: integer(), metrics: map(string(), container()), times: map(integer(), integer()), step_state: container(), @@ -18,15 +16,9 @@ defmodule Axon.Loop.State do `epoch` is the current epoch, starting at 0, of the nested loop. Defaults to 0. - `max_epoch` is the maximum number of epochs the loop should run - for. Defaults to 1. - `iteration` is the current iteration of the inner loop. In supervised settings, this will be the current batch. Defaults to 0. - `max_iteration` is the maximum number of iterations the loop should - run a given epoch for. Defaults to -1 (no max). - `metrics` is a map of `%{"metric_name" => value}` which accumulates metrics over the course of loop processing. Defaults to an empty map. @@ -52,9 +44,7 @@ defmodule Axon.Loop.State do :status, handler_metadata: %{}, epoch: 0, - max_epoch: 1, iteration: 0, - max_iteration: -1, metrics: %{}, times: %{}, event_counts: %{ diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index 34c002ab..ac5c3b84 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -9,12 +9,11 @@ defmodule Axon.LoopTest do test "loop/3 creates a basic loop with defaults" do step_fn = fn _, _ -> Nx.tensor(1) end - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.loop(step_fn) assert_equal(init_fn.(Nx.tensor(1), %{}), %{}) assert_equal(update_fn.({}, %{}), Nx.tensor(1)) - assert_equal(transform.(%{}), %{}) end test "trainer/3 returns a supervised training loop with basic case" do @@ -40,7 +39,7 @@ defmodule Axon.LoopTest do for loss <- valid_axon_losses do for optimizer <- valid_axon_optimizers do - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.trainer(model, loss, optimizer) assert %{model_state: %Axon.ModelState{}} = @@ -54,8 +53,6 @@ defmodule Axon.LoopTest do assert_equal(tar, Nx.tensor([[1]])) assert_equal(pred, Nx.tensor([[1]])) - - assert_equal(transform.(state), %{}) end end end @@ -64,7 +61,7 @@ defmodule Axon.LoopTest do model = Axon.input("input", shape: {nil, 1}) custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.BinaryBackend) end - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.trainer(model, custom_loss_fn, :adam) assert %{model_state: %{}} = @@ -78,15 +75,13 @@ defmodule Axon.LoopTest do assert_equal(tar, Nx.tensor([[1]])) assert_equal(pred, Nx.tensor([[1]])) assert_equal(loss, Nx.tensor(5.0)) - - assert_equal(transform.(state), %{}) end test "trainer/3 returns a supervised training loop with custom optimizer" do model = Axon.input("input", shape: {nil, 1}) optimizer = Polaris.Optimizers.rmsprop(learning_rate: 1.0e-3) - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.trainer(model, :mean_squared_error, optimizer) assert %{model_state: %{}} = @@ -99,14 +94,12 @@ defmodule Axon.LoopTest do assert_equal(tar, Nx.tensor([[1]])) assert_equal(pred, Nx.tensor([[1]])) - - assert_equal(transform.(state), %{}) end test "trainer/3 returns a supervised training loop with custom model" do model = Axon.input("input", shape: {nil, 1}) |> Axon.build(mode: :train) - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.trainer(model, :mean_squared_error, :adam) assert %{model_state: %{}} = @@ -119,8 +112,6 @@ defmodule Axon.LoopTest do assert_equal(tar, Nx.tensor([[1]])) assert_equal(pred, Nx.tensor([[1]])) - - assert_equal(transform.(state), %{}) end test "trainer/3 returns a supervised training loop with multi-loss" do @@ -128,7 +119,7 @@ defmodule Axon.LoopTest do {Axon.input("input_0", shape: {nil, 1}), Axon.input("input_1", shape: {nil, 1})} |> Axon.container() - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.trainer(model, [mean_squared_error: 0.5, mean_absolute_error: 0.5], :adam) assert %{model_state: %{}} = @@ -146,8 +137,6 @@ defmodule Axon.LoopTest do assert_equal(tar, {Nx.tensor([[2]]), Nx.tensor([[2]])}) assert_equal(pred, {Nx.tensor([[1]]), Nx.tensor([[1]])}) assert_equal(loss, Nx.tensor(1.0)) - - assert_equal(transform.(state), %{}) end test "trainer/3 raises on bad inputs" do @@ -174,7 +163,7 @@ defmodule Axon.LoopTest do expected_pred = Axon.predict(model, model_state, inp) - assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + assert %Loop{init: init_fn, step: update_fn} = Loop.evaluator(model) assert %{model_state: _, y_true: _, y_pred: _} = @@ -187,8 +176,6 @@ defmodule Axon.LoopTest do assert_equal(tar, Nx.tensor([[2]])) assert_equal(pred, expected_pred) - - assert_equal(transform.(state), %{"my_metric" => {}}) end test "evaluator/1 runs a supervised evaluator loop" do @@ -856,7 +843,6 @@ defmodule Axon.LoopTest do assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} = loop - |> Map.put(:output_transform, & &1) |> Loop.checkpoint(event: :iteration_completed, filter: [every: 2]) |> Loop.run(data, Axon.ModelState.empty(), epochs: 3) @@ -907,8 +893,6 @@ defmodule Axon.LoopTest do state1 = model |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) - # TODO: Make this an actual function or configurable - |> Map.put(:output_transform, & &1) |> Axon.Loop.run(data, Axon.ModelState.empty(), epochs: 3, iterations: 5) model @@ -1003,8 +987,6 @@ defmodule Axon.LoopTest do |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.early_stop("counter", mode: :min, patience: 2) - # TODO: This API needs to change - |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.run(data, Axon.ModelState.empty(), epochs: 5, iterations: 5) assert %{epoch: 3} = state @@ -1031,8 +1013,6 @@ defmodule Axon.LoopTest do |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.early_stop("counter", mode: :max, patience: 2) - # TODO: This API needs to change - |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.run(data, Axon.ModelState.empty(), epochs: 5, iterations: 5) assert %{epoch: 3} = state @@ -1096,8 +1076,6 @@ defmodule Axon.LoopTest do ) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.reduce_lr_on_plateau("counter", factor: 0.5, mode: :min, patience: 2) - # TODO: This API needs to change - |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.run(data, Axon.ModelState.empty(), epochs: 7, iterations: 5) assert %{step_state: %{optimizer_state: optimizer_state}} = state @@ -1132,8 +1110,6 @@ defmodule Axon.LoopTest do ) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.reduce_lr_on_plateau("counter", factor: 0.5, mode: :max, patience: 2) - # TODO: This API needs to change - |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.run(data, Axon.ModelState.empty(), epochs: 7, iterations: 5) assert %{step_state: %{optimizer_state: optimizer_state}} = state