-
Notifications
You must be signed in to change notification settings - Fork 816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Emotion2vec support onnx #2359
Conversation
We're almost done adding support for it. It looks like the only missing part is layer normalization during conversion. |
OK, I would have times to check it after 14 Jan. |
@LauraGPT @thewh1teagle In my assumption, I think the error occurs due to the dynamic computation from the layerNormalization and ONNX does not support this. I am thinking about replacing with the actual mathematical computation of the normalization since the following are equivalent. import torch
import torch.nn.functional as F
x = torch.rand(1, 10)
x_norm_ref = F.layer_norm(x, x.shape)
mean = torch.mean(x, dim=1, keepdim=True)
var = torch.var(x, dim=1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + 1e-5)
torch.all(torch.isclose(x_norm, x_norm_ref))
>> tensor(True) - x = F.layer_norm(x, x.shape)
+ mean = torch.mean(x, dim=1, keepdim=True)
+ var = torch.var(x, dim=1, keepdim=True, unbiased=False)
+ x = (x - mean) / torch.sqrt(var + 1e-5)
+ x = x.view(x.shape[0], -1) |
Did you able to check it? thanks |
I have tested it without any errors. Sorry, I do not know what the errors you ref to? |
I gave up on keep testing it. So thanks to @takipipo seems like it works now |
Thanks, I would check it and release a new version of |
Not an issue in FunASR, but I will really appreciate if you could take a look on it |
From the following feature requests
ONNX? ddlBoJack/emotion2vec#55
emotion2vec onnx #2291
Amazing work did by thewh1teagle@586e81d
Include LayerNorm to the Onnx model
I have made emotion2vec exportable to onnx with dynamic audio sequence length. Sadly, I cannot include LayerNorm to theexport_forward
it produces this errorError Log
How to export
Feel free to review the code and give feedbacks 😁