Description
Currently, the conversion to/from NIR of the following PyTorch modules is handled individually in each PyTorch-based framework:
- torch.nn.Conv2d
- torch.nn.Linear
In the future this might apply to other modules such as:
- torch.nn.Conv1d and other convolution operations
- torch.nn.AvgPool2d and other average pooling operations
- torch.nn.MaxPool2d and other max pooling operations
This is related to this issue in snntorch: jeshraghian/snntorch#304
Suggestion
To reduce the amount of redundant code, we could implement default mapper functions for these native PyTorch operations. A default mapper function would be applied to an nn.Module or a NIR node whenever the framework-specific mapper function does not handle the operation.
Changes for conversion from pytorch to NIR
The suggestion is to implement a default mapper function, for example _extract_default_model, and to call it if the framework-dependent model_map function that was provided does not find a NIRNode: https://github.com/neuromorphs/NIRTorch/blob/main/nirtorch/to_nir.py#L70
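A rough sketch of the fallback idea, to make the proposal concrete. It deliberately uses stand-in classes (FakeLinear, FakeNIRAffine) instead of torch and nir so it runs on its own; those classes, the wrapper with_default_fallback, and the convention that model_map returns None for unsupported modules are all assumptions for illustration, not the current NIRTorch API.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


# Hypothetical stand-ins for torch.nn.Linear and a NIR affine node, so the
# sketch runs without torch or nir installed.
@dataclass
class FakeLinear:
    weight: List[List[float]]            # shape: (out_features, in_features)
    bias: Optional[List[float]] = None


@dataclass
class FakeNIRAffine:
    weight: List[List[float]]
    bias: List[float]


def _extract_default_model(module) -> Optional[object]:
    """Hypothetical default mapper for framework-independent modules."""
    if isinstance(module, FakeLinear):
        # A missing bias is filled with zeros, one entry per output feature.
        bias = module.bias if module.bias is not None else [0.0] * len(module.weight)
        return FakeNIRAffine(weight=module.weight, bias=bias)
    return None  # operation not covered by the defaults


def with_default_fallback(model_map: Callable) -> Callable:
    """Wrap a framework-specific model_map so the defaults are tried second.

    Assumes model_map returns None when it does not know the module.
    """
    def mapper(module):
        node = model_map(module)
        if node is None:
            node = _extract_default_model(module)
        if node is None:
            raise ValueError(f"No NIR mapping found for {type(module).__name__}")
        return node
    return mapper
```

With this shape, a framework only needs to map its own neuron models; anything it returns None for falls through to the shared defaults.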
Changes for conversion from NIR to pytorch
Extend the function _switch_default_models so that it also covers the operations listed above: https://github.com/neuromorphs/NIRTorch/blob/main/nirtorch/from_nir.py#L188