Issues about OneHotVector/OneHotMatrix #1445

Closed
chengchingwen opened this issue Dec 28, 2020 · 9 comments · Fixed by #1448

@chengchingwen
Member

Mentioned in #1431

There are some issues with the one-hot implementation. I'll list them below.

  1. Higher dimension support

Currently Flux only supports 1- and 2-dimensional one-hot arrays. However, we often need more than that. For example, Transformers usually require a one-hot array with shape (num labels, sequence length, batch size) for parallelization. Besides Transformers, image tasks that do pixel-level classification, such as semantic segmentation, also need this (see the first sketch after this list).

  2. Array interface

There are some array operations that not only convert a one-hot array to a Boolean array but also copy the data back to the CPU. For example, if you hcat two GPU OneHotMatrix, the result is an Array{Bool}, but the correct result should be OneHotMatrix{CuArray{OneHotVector,1}} (see the second sketch after this list).

  3. Memory consumption

We only support up to 2^32 labels (the maximum value of a UInt32), yet each OneHotVector takes 64 bits. The problem comes from OneHotMatrix, which is actually just a container for a Vector{OneHotVector}: every element redundantly stores the label count, so we use twice the memory actually needed (see the third sketch after this list).

(4. There used to be some problems when using OneHotVector with custom CUDA kernels, but most of them seem to be gone after the CUDA.jl update. I list this here just in case anyone has encountered similar problems.)
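
For item 1, a minimal sketch of the gap. The shapes and the reshape workaround below are my own illustration; Flux has no N-dimensional one-hot API today:

```julia
using Flux

# 2-D case, supported today: one one-hot column per label in the sequence.
seq = [3, 1, 4, 2]
oh2 = Flux.onehotbatch(seq, 1:4)            # 4×4 OneHotMatrix

# What a Transformer batch needs: (num labels, sequence length, batch size).
# With the current API the only route I know of is to flatten, encode, and
# reshape, which hands back a plain reshaped Bool view rather than a
# one-hot-aware 3-D type that GPU kernels and specialised methods could
# dispatch on.
labels = rand(1:4, 10, 8)                   # (sequence length, batch size)
oh3 = reshape(Flux.onehotbatch(vec(labels), 1:4), 4, 10, 8)
```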
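
For item 2, an untested sketch of the hcat problem. It assumes a working CUDA.jl setup and that `gpu` moves the OneHotMatrix storage to a CuArray:

```julia
using Flux, CUDA

x = Flux.onehotbatch([1, 2, 3], 1:3) |> gpu   # OneHotMatrix with GPU storage
y = Flux.onehotbatch([3, 2, 1], 1:3) |> gpu

z = hcat(x, y)
# Observed: z isa Array{Bool,2}, i.e. the one-hot structure is lost and the
# data is copied back to the CPU.
# Expected: a OneHotMatrix whose storage stays on the GPU, roughly
# OneHotMatrix{CuArray{OneHotVector,1}}.
```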
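
For item 3, a sketch of where the redundancy comes from and of one possible compact layout. The CompactOneHotMatrix name and fields are hypothetical, only meant to illustrate the memory argument:

```julia
# Current layout (roughly): every OneHotVector carries both the hot index and
# the label count, so a OneHotMatrix with n columns stores the label count n
# times over: 2n UInt32s where n UInt32s plus one shared count would suffice.

# Hypothetical compact layout: keep only the indices, store the height once.
struct CompactOneHotMatrix{I<:AbstractVector{UInt32}} <: AbstractMatrix{Bool}
  nlabels::Int   # number of labels, stored once for the whole batch
  indices::I     # one UInt32 index per column; can live on the CPU or GPU
end

Base.size(x::CompactOneHotMatrix) = (x.nlabels, length(x.indices))
Base.getindex(x::CompactOneHotMatrix, i::Int, j::Int) = x.indices[j] == UInt32(i)
```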

@CarloLucibello
Member

Do you have an alternative implementation in Transformers.jl fixing these issues? Can it be retrofitted here, possibly without breaking the interface?

@chengchingwen
Member Author

I do have one. I think it won't be too hard to build the same interface on top of it.

@CarloLucibello
Member

Great, would you file a PR whenever you have time? Actually, most of that Embeddings module and the MultiHeadAttention layer should be cannibalized by Flux (if you agree) and made available for general use outside of Transformers.jl, as requested by many in #1431 (comment). OneHot arrays seem like a well-isolated piece of functionality to start with.

@chengchingwen
Member Author

Sure, I'll probably do it around the new year.

Actually, most of that Embeddings module and the MultiHeadAttention layer should be cannibalized by Flux

I'm ok with it. I can make a general version of MultiHeadAttention for Flux. On the other hand, which Embeddings module are we talking about?

@CarloLucibello
Member

On the other hand, which Embeddings module are we talking about?

The Embeds.jl in your repo. Actually, maybe that's too much, I don't know, I'm not an expert on these things and I don't know how popular they are.
Maybe just import here something similar to what PyTorch has?
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag

@chengchingwen
Member Author

I remember that we used to have an Embed layer in Flux; not sure why it's gone. I think moving the entire Embedding module from Transformers to Flux would be too much, but a basic Embed layer definition should be fine.

@DhairyaLGandhi
Member

Definitely agree on the embedding layer. Someone mentioned it was non-trivial to have a GPU-compliant Embedding layer (@darsnack?), and this would totally need to be on the list.

@darsnack
Member

No, I think I said that w.r.t. the upsampling layers.

@chengchingwen
Member Author

I think a GPU-compliant Embedding layer would be easy once we have gather/scatter support in NNlib.
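
For reference, a rough sketch of what such a layer could look like once gather lands in NNlib. The Embed name, constructor, and fields are assumptions for illustration, not an agreed design:

```julia
using Flux, NNlib

struct Embed{W<:AbstractMatrix}
  weight::W    # (embedding size, vocabulary size)
end

Embed(vocab::Int, dims::Int) = Embed(Flux.glorot_uniform(dims, vocab))

# gather picks columns of `weight` by integer id; for higher-dimensional id
# arrays the result has shape (dims, size(ids)...), which also covers the
# (sequence length, batch size) case from item 1 above.
(e::Embed)(ids::AbstractArray{<:Integer}) = NNlib.gather(e.weight, ids)

Flux.@functor Embed    # parameters move with `gpu` and show up in `params`

# Usage sketch:
# emb = Embed(10_000, 128)
# emb([1, 5, 42])              # 128×3 Matrix
# emb(rand(1:10_000, 20, 16))  # 128×20×16 Array
```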
