Description
Axon.container lets us return a map of multiple outputs. Most of the models in Bumblebee do this, for example returning %{logits: ..., encoder_hidden_state: ...}. I have found that this can cause issues in quite a few places. One is Axon.Display.as_table, which doesn't seem to handle maps/containers or Axon.None. Another is Axon.Loop.validate: the Axon.Loop.eval_step function takes the model output and passes it directly into Nx.type/1 and Nx.shape/1, neither of which handles maps/containers.
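Here is a minimal sketch of the second problem that does not involve Axon at all. The map keys mirror the Bumblebee-style output above. The tensor shapes and the Nx version requirement are placeholders I picked for illustration. I expect the call on the map to raise, so the sketch only records that an error occurred rather than assuming a specific exception type.

```elixir
# Assumption: any reasonably recent Nx release; the version below is illustrative.
Mix.install([{:nx, "~> 0.7"}])

# Stand-in for a multi-output model result (shapes are made up).
output = %{
  logits: Nx.iota({2, 3}, type: :f32),
  encoder_hidden_state: Nx.iota({2, 4}, type: :f32)
}

# Works: a single tensor has a shape.
IO.inspect(Nx.shape(output.logits), label: "tensor shape")

# Expected to fail: a map of tensors is not a tensor, number, or shape tuple.
result =
  try do
    {:ok, Nx.shape(output)}
  rescue
    error -> {:error, error.__struct__}
  end

IO.inspect(result, label: "shape of map")
```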
I'm a little unclear on why this init function can't just return the model output directly. It seems to recreate the shape of the model's output with all zeros, but it only handles outputs that are a single tensor. Forgive my ignorance if there's something I'm missing, but can't this just be reduced from:
```elixir
{inp, tar}, state ->
  # TODO: Is this expensive
  output = forward_model_fn.(state, inp)
  output_type = Nx.type(output)
  output_shape = Nx.shape(output)
  y_pred = Nx.broadcast(Nx.tensor(0, type: output_type), output_shape)

  %{
    model_state: state,
    y_true: zeros_like(tar),
    y_pred: y_pred
  }
```

to:
```elixir
{inp, tar}, state ->
  # TODO: Is this expensive
  output = forward_model_fn.(state, inp)

  %{
    model_state: state,
    y_true: zeros_like(tar),
    y_pred: output
  }
```

I'm happy to write up a PR if this idea seems reasonable, but admittedly my knowledge of the codebase and the repercussions of this change is close to zero.
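If the zero-filled initial value has to stay for reasons I'm not aware of, a middle ground might be to build the zeros per leaf instead of calling Nx.type/1 and Nx.shape/1 on the whole output. Below is a rough sketch of that idea, assuming Nx.Defn.Composite.traverse/2 is an acceptable way to walk the container. The ContainerZeros module and its zeros_like/1 function are hypothetical names for illustration, not Axon's existing helper, and I have not checked how this behaves for Axon.None.

```elixir
# Assumption: any reasonably recent Nx release; the version below is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule ContainerZeros do
  # Hypothetical helper: returns a container with the same structure as the
  # input, where every tensor leaf is replaced by zeros of the same type and shape.
  def zeros_like(container) do
    Nx.Defn.Composite.traverse(container, fn leaf ->
      Nx.broadcast(Nx.tensor(0, type: Nx.type(leaf)), Nx.shape(leaf))
    end)
  end
end

# Stand-in for a multi-output model result (shapes and types are made up).
output = %{
  logits: Nx.iota({2, 3}, type: :f32),
  encoder_hidden_state: Nx.iota({2, 4}, type: :s64)
}

y_pred = ContainerZeros.zeros_like(output)

IO.inspect(Map.keys(y_pred), label: "keys")
IO.inspect(Nx.shape(y_pred.logits), label: "logits shape")
IO.inspect(Nx.type(y_pred.encoder_hidden_state), label: "hidden state type")
```

Since a plain tensor is also handled by the traversal, the same helper would cover the single-output case, so the init function would not need a separate branch for it.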