Block with multiple inputs fails to compile #540

Closed
mntns opened this issue Nov 1, 2023 · 2 comments · Fixed by #574

mntns commented Nov 1, 2023

When passing multiple inputs to a block, the model fails to compile:

```elixir
Mix.install(
  [
    {:exla, ">= 0.0.0"},
    {:axon, path: "./axon", overwrite: true},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

input1 = Axon.input("input1")
input2 = Axon.input("input2")
input3 = Axon.input("input3")
input4 = Axon.input("input4")

reuse = Axon.block(fn x1, x2 ->
  Axon.add(x1, x2)
end)
model_1 = reuse.([input1, input2])
model_2 = reuse.([input3, input4])

out = Axon.container({model_1, model_2})

template = %{
  "input1" => Nx.template({2, 8}, :f32),
  "input2" => Nx.template({2, 8}, :f32),
  "input3" => Nx.template({2, 8}, :f32),
  "input4" => Nx.template({2, 8}, :f32)
}

{init_fn, predict_fn} = Axon.compile(out, template)
```

The following error is raised:

```
** (Axon.CompileError) exception found when compiling layer Axon.Layers.block/3 named block_0:

    ** (UndefinedFunctionError) function Axon.Layers.block/3 is undefined or private
        (axon 0.6.0) Axon.Layers.block(#Nx.Tensor<
      f32[2][8]

      Nx.Defn.Expr
      parameter a:0   f32[2][8]
    >, #Nx.Tensor<
      f32[2][8]

      Nx.Defn.Expr
      parameter a:1   f32[2][8]
    >, [mode: :inference, block_fun: #Function<0.123201197 in file:block_without_start.exs>, block_id: 5])


(pass debug: true to build/compile see where the layer was defined)


Compiling of the model was initiated at:

    (nx 0.6.2) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (nx 0.6.2) lib/nx/defn/evaluator.ex:83: Nx.Defn.Evaluator.precompile/3
    (nx 0.6.2) lib/nx/defn/evaluator.ex:61: Nx.Defn.Evaluator.__compile__/4
    (nx 0.6.2) lib/nx/defn/evaluator.ex:54: Nx.Defn.Evaluator.__jit__/5
    (nx 0.6.2) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (axon 0.6.0) lib/axon.ex:3645: Axon.compile/4
```

mntns commented Nov 1, 2023

Upon further investigation, it seems that `Axon.Compiler.recur_model_funs` can never match an `%Axon.Node{op: :block}` with multiple parents (https://github.com/elixir-nx/axon/blob/main/lib/axon/compiler.ex#L464). Instead, the "generic" clause for built-in `Axon.Layers` nodes matches (https://github.com/elixir-nx/axon/blob/main/lib/axon/compiler.ex#L665), so the compiler erroneously builds a call to `Axon.Layers.block/n` (here `Axon.Layers.block/3`), which does not exist.
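
The fall-through described above can be illustrated with a small pure-Elixir sketch. `SketchNode`, `DispatchSketch`, and all of their clauses are hypothetical simplifications, not Axon's actual `Axon.Compiler` code. They only show how a `:block` clause that pins the parent list to a single element lets a two-parent block reach a generic clause, which (as in the error above) would call `Axon.Layers.block` with one argument per parent plus a trailing options list, i.e. arity 3.

```elixir
# Hypothetical, simplified stand-in for an Axon graph node (NOT Axon's real %Axon.Node{}).
defmodule SketchNode do
  defstruct [:op, parent: []]
end

defmodule DispatchSketch do
  # Buggy dispatch (hypothetical): the :block clause only matches a node with
  # exactly one parent, so a two-parent block falls through to the generic clause.
  def buggy(%SketchNode{op: :block, parent: [_single]}), do: :compile_block
  def buggy(%SketchNode{} = node), do: generic_call(node)

  # Fixed dispatch (hypothetical): the :block clause accepts any non-empty parent list.
  def fixed(%SketchNode{op: :block, parent: [_ | _]}), do: :compile_block
  def fixed(%SketchNode{} = node), do: generic_call(node)

  # The generic path would call Axon.Layers.<op> with one argument per parent
  # plus a trailing keyword list of options, hence arity = parents + 1.
  defp generic_call(%SketchNode{op: op, parent: parents}),
    do: {:generic_call, op, length(parents) + 1}
end

two_parent_block = %SketchNode{op: :block, parent: [:input1, :input2]}

# Prints {:generic_call, :block, 3}, the analogue of calling Axon.Layers.block/3
IO.inspect(DispatchSketch.buggy(two_parent_block))
# Prints :compile_block
IO.inspect(DispatchSketch.fixed(two_parent_block))
```

Matching on `[_ | _]` keeps single-parent blocks on the block-specific path while also routing multi-parent blocks there; this sketch does not claim to reflect how the fix in #574 was actually implemented.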

seanmor5 (Contributor) commented

I'm still trying to think of a good way to handle these cases. For now, I think you should wrap all inputs in a container and then match on them inside the block.
