
Axon.Loop.validate doesn't work with models that return a container #612

@pejrich

Description


`Axon.container` lets us return a map of multiple outputs. For example, most of the models in Bumblebee return something like `%{logits: ..., encoder_hidden_state: ...}`. However, I have found that this causes issues in quite a few places. One is `Axon.Display.as_table`, which doesn't seem to handle maps/containers or `Axon.None`. Another is `Axon.Loop.validate`: in `Axon.Loop.eval_step`, the model output is passed directly to `Nx.type/1` and `Nx.shape/1`, neither of which handles maps/containers.

https://github.com/elixir-nx/axon/blob/35004b489e4a0fd5367445c57d0ce6c96f879073/lib/axon/loop.ex#L481C1-L482C38
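To illustrate the mismatch: `Nx.type/1` and `Nx.shape/1` describe a single tensor, so for a container output they have to be applied to each entry separately. The sketch below assumes a flat map of tensors shaped like a Bumblebee output; the keys, shapes, and the Nx version in `Mix.install` are illustrative assumptions, not taken from a real model.

```elixir
# Assumption: any recent Nx release; pinned loosely for a standalone script.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical model output shaped like a Bumblebee container:
# a flat map of tensors (keys and shapes are made up for illustration).
output = %{
  logits: Nx.iota({2, 3}, type: {:f, 32}),
  encoder_hidden_state: Nx.iota({2, 4, 8}, type: {:f, 32})
}

# Nx.type/1 and Nx.shape/1 work on one tensor at a time, so the
# container is inspected leaf by leaf, keeping the original keys.
per_leaf =
  Map.new(output, fn {key, tensor} ->
    {key, {Nx.type(tensor), Nx.shape(tensor)}}
  end)

IO.inspect(per_leaf)
```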

I'm a little unclear why this init function can't just return the model output directly. It seems to recreate the shape of the model's output filled with zeros, but it only handles outputs that are a single tensor. Forgive my ignorance if I'm missing something, but couldn't this be reduced from:

```elixir
{inp, tar}, state ->
  # TODO: Is this expensive
  output = forward_model_fn.(state, inp)
  output_type = Nx.type(output)
  output_shape = Nx.shape(output)
  y_pred = Nx.broadcast(Nx.tensor(0, type: output_type), output_shape)

  %{
    model_state: state,
    y_true: zeros_like(tar),
    y_pred: y_pred
  }
```

to:

```elixir
{inp, tar}, state ->
  # TODO: Is this expensive
  output = forward_model_fn.(state, inp)

  %{
    model_state: state,
    y_true: zeros_like(tar),
    y_pred: output
  }
```
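An alternative that keeps the existing all-zeros placeholder would be to build the zeros leaf by leaf instead of calling `Nx.type/1` and `Nx.shape/1` on the whole output. This is a minimal sketch, assuming the output is either a tensor or a nested map/tuple of tensors. `ContainerZeros.zeros_like_output/1` is a hypothetical helper, not part of Axon or Nx, and it runs eagerly here; whether it behaves identically inside Axon's compiled step function is an assumption.

```elixir
# Assumption: any recent Nx release; pinned loosely for a standalone script.
Mix.install([{:nx, "~> 0.7"}])

defmodule ContainerZeros do
  # Hypothetical helper (not part of Axon): returns a value with the same
  # structure as `output`, where every tensor leaf is replaced by zeros of
  # the same type and shape (mirroring the existing single-tensor code).
  def zeros_like_output(%Nx.Tensor{} = tensor) do
    Nx.broadcast(Nx.tensor(0, type: Nx.type(tensor)), Nx.shape(tensor))
  end

  # Plain maps such as %{logits: ..., encoder_hidden_state: ...} are rebuilt
  # key by key. Other structs are excluded because they are not enumerable.
  def zeros_like_output(map) when is_map(map) and not is_struct(map) do
    Map.new(map, fn {key, value} -> {key, zeros_like_output(value)} end)
  end

  # Tuples are rebuilt element by element, preserving their order.
  def zeros_like_output(tuple) when is_tuple(tuple) do
    tuple
    |> Tuple.to_list()
    |> Enum.map(&zeros_like_output/1)
    |> List.to_tuple()
  end
end

# Hypothetical container output with a nested tuple of integer tensors.
output = %{
  logits: Nx.tensor([[1.5, -2.0], [0.25, 3.0]]),
  hidden: {Nx.tensor([1, 2, 3]), Nx.tensor(7)}
}

zeros = ContainerZeros.zeros_like_output(output)
IO.inspect(zeros)
```

Leaves that are neither tensors, plain maps, nor tuples (for example an `%Axon.None{}` struct) match no clause and raise a `FunctionClauseError`, so they would need their own handling.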

I'm happy to write up a PR if this idea seems reasonable, though admittedly I know very little about the codebase or the repercussions of this change.

Metadata

- Assignees: No one assigned
- Labels: No labels
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests