
Passing f32 data into LSTM with Axon.Loop trainer+run causes while shape mismatch error #490

Closed
polvalente opened this issue Apr 17, 2023 · 9 comments

@polvalente (Contributor)

input = Axon.input("input_series", shape: put_elem(Nx.shape(time_x), 0, nil))

# Axon.lstm returns {output_sequence, hidden_state};
# elem(0) keeps the output sequence for the next layer.
model =
  input
  |> Axon.lstm(128, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(64, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(32, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.dense(1)

model
|> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adam())
|> Axon.Loop.run(
  Stream.zip(
    Nx.to_batched(time_x, 50),
    Nx.to_batched(Nx.new_axis(time_y, 2), 50)
  )
)
@seanmor5 (Contributor)

Btw, looking at this: it's not advisable to use dropout after an LSTM layer. See https://arxiv.org/pdf/1512.05287.pdf

This is still a bug, though.
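For reference, a minimal sketch of what that paper suggests instead (variational dropout: sample one mask per sequence and reuse it at every time step, rather than resampling per step). This is illustrative Nx code under made-up names (VariationalDropout, drop/3), not an existing Axon layer:

defmodule VariationalDropout do
  import Nx.Defn

  # One Bernoulli mask per sequence, broadcast across the time axis
  # (axis 1), following https://arxiv.org/pdf/1512.05287.pdf
  defn drop(x, key, opts \\ []) do
    opts = keyword!(opts, rate: 0.25)
    batch = Nx.axis_size(x, 0)
    features = Nx.axis_size(x, 2)
    {u, _key} = Nx.Random.uniform(key, shape: {batch, 1, features})
    keep = Nx.greater(u, opts[:rate])
    x * keep / (1 - opts[:rate])
  end
end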

@polvalente (Contributor, Author)

I was just copying a Kaggle solution to practice RNNs :)

Thanks for the advice!

@krstopro (Member)

@polvalente Looking into this, but I am not getting the error. Could you please tell me what the shapes of time_x and time_y are? Thanks!

@polvalente (Contributor, Author)

@krstopro I pivoted from this approach, but I believe that x and y were just rank-1 tensors.

The error appears for floating-point inputs but not integer inputs, IIRC.

@krstopro (Member) commented Jun 15, 2023

@polvalente time_y is definitely not rank 1, since a third dimension is added with Nx.new_axis(time_y, 2).

The following code seems to work for me (x and y are f32).

key = Nx.Random.key(12)
# thread the key through successive calls so randomness isn't reused
{x, key} = Nx.Random.normal(key, shape: {12, 6, 3})
{y, _key} = Nx.Random.normal(key, shape: {12, 6})

input = Axon.input("input_series", shape: put_elem(Nx.shape(x), 0, nil))

model =
  input
  |> Axon.lstm(128, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(64, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(32, activation: :relu)
  |> elem(0)
  |> Axon.dropout(rate: 0.25)
  |> Axon.dense(1)

model
|> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam())
|> Axon.Loop.run(
  Stream.zip(
    Nx.to_batched(x, 6),
    Nx.to_batched(Nx.new_axis(y, 2), 6)
  )
)

So everything seems legit, but I might need to further inspect the output shapes of the LSTMs.
I am not sure whether the LSTM returns just the last hidden state or the full sequence of hidden states across time.
Also, I don't know whether the first dimension should be batch or time (e.g. in PyTorch it is time/length by default: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html).
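One quick way to check the output shape (a sketch; the second argument to init_fn differs across Axon versions, e.g. %{} or Axon.ModelState.empty()):

{init_fn, predict_fn} = Axon.build(model)
# initialize with a template carrying a concrete batch size
params = init_fn.(Nx.template({6, 6, 3}, :f32), %{})
[x_batch | _] = Enum.to_list(Nx.to_batched(x, 6))
predict_fn.(params, x_batch) |> Nx.shape()
# {6, 6, 1} here would mean the stack emits one value per time step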

@polvalente (Contributor, Author)

@seanmor5 did you fix this bug at the time?

@krstopro the problem I ran into was related to some of the inner random key state being upcast from u32 to f32.
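For illustration (my own sketch, not the exact failure path in the trainer): Nx PRNG keys are u32 tensors that the generator treats as raw bits, so a blanket cast of model state to f32 corrupts them:

key = Nx.Random.key(42)
# => #Nx.Tensor<u32[2]> -- raw bit patterns, not numeric data
Nx.as_type(key, :f32)
# => f32[2]; subsequent Nx.Random.* calls with this "key" will
# raise or misbehave, since they expect a u32 key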

@danieljaouen

I am also seeing an issue which I think has to do with my model being cast to :f32 (from :f64, in my case). See: https://elixirforum.com/t/getting-batches-to-work-with-axon/57482/3

@dc0d commented Nov 5, 2023

Facing the same issue.

Managed to solve it using Nx.as_type(df, :f32) (but the original problem still remains):

...
  # cast every column to f32 before concatenating, so the resulting
  # tensor matches the model's parameter type
  def df_to_tensor(df) do
    df
    |> Explorer.DataFrame.names()
    |> Enum.map(fn name ->
      df[name]
      |> Explorer.Series.to_tensor()
      |> Nx.new_axis(-1)
      |> Nx.as_type(:f32)
    end)
    |> Nx.concatenate(axis: 1)
  end
...
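For example (assuming a small Explorer.DataFrame; the result shape/type is shown as a comment):

df = Explorer.DataFrame.new(%{a: [1.0, 2.0], b: [3.0, 4.0]})
df_to_tensor(df)
# => #Nx.Tensor<f32[2][2]>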

@seanmor5 (Contributor)

This issue should be fixed with the new Axon.ModelState changes: dropout keys and other model state are no longer considered part of the training parameters, so they shouldn't accidentally get cast anywhere.
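A minimal sketch of what that separation looks like (assuming a recent Axon with Axon.ModelState; exact function names and struct fields may differ by version):

{init_fn, _predict_fn} = Axon.build(model, mode: :train)
state = init_fn.(Nx.template({6, 6, 3}, :f32), Axon.ModelState.empty())
# trainable parameters (kernels/biases) now live apart from per-layer
# state such as dropout PRNG keys, so optimizer updates and parameter
# casts never touch the u32 keys
Axon.ModelState.trainable_parameters(state)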
