Generalize the Fourier transform API #86
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR is a proposal to refactor the Fourier transform API, with the goal of making it easier to incorporate Fourier transforms into other modules. Here are the two specific use-cases I was trying to facilitate:
A Fourier max pooling layer, as discussed in Question about pooling with Fourier representations in R^3 #65. This would be very similar to the existing
FourierPointwise
class, except after the nonlinearity, there would also be max-pooling and Gaussian blurring steps.An IFT as an output layer. It's probably not clear what I mean by that, and it's possible that I only needed such a layer because I overlooked some easier way of doing things, so I want to take some time to explain the problem I was trying to solve. My goal was to reimplement [Doersch2016], but in 3D and with equivariance. The idea in [Doersch2016] is to create a self-supervised training protocol by taking two nearby crops of an image, and having the model predict the location of the second relative to the first. There would only be a handful of possible relative locations, e.g. above, below, right, and left (for 2D images). I implemented this by having the final layer of my model be a single spectral regular representation (of the quotient space$S^2 = SO(3) / SO(2)$ , because the two crops cannot rotate relative to each other), then performing an IFT with each grid point corresponding to one of the possible relative locations. This results in values for each location that can be interpreted as logits. And if the input rotates, so do the logits. To bring this back to the PR at hand, the important point is that this application requires being able to perform an IFT without a subsequent FT.
I think that the best way to support these two use-cases, and possibly others that I haven't thought of, is to create separate FT and IFT modules. That's what the proposed API does. Here are the specific classes involved:
InverseFourierTransform
: A pytorch module where the input is a tensor with a spectral regular representation, and the output is a tensor of signal values sampled on a grid.FourierTransform
: The opposite ofInverseFourierTransform
. This module also provides the option to prepare the FT matrix with more irreps than will ultimately be output.FourierFieldType
: Most equivariant modules accept input/output field types as arguments, butFourierPointwise
is an exception. It acceptsgspace
,channels
, andirreps
arguments, and uses them to create a compatible field type under the hood. This API is a bit awkward to begin with, but it's worse when the same arguments need to be passed to two different modules.To bring the Fourier API in line with all the other modules, I created
FourierFieldType
. This is a subclass ofFieldType
that only allows spectral regular representations (possibly with respect to a quotient space). The IFT and FT modules require this field type (and check for it). Other modules are agnostic to it.GridTensor
: A class that wraps the output of an IFT and the input to an FT. It's similar in concept toGeometricTensor
, except that instead of keeping track of the representation associated with a tensor, it keeps track of the grid. This lets the FT module check that it's compatible with the input it receives, and (for GNNs) restore thecoords
attribute.Using these classes, I reimplemented the
FourierPointwise
class in a way that I believe to be 100% backwards-compatible. The new implementation also removes hundreds of lines of code that were duplicated betweenFourierPointwise
andQuotientFourierPointwise
. Below is a simplifiedFourierRelu
version of this class, just to give a sense for how it works:Minor comments:
This PR isn't ready to be merged yet. I haven't updated the documentation, and although all the existing tests pass, I want to write some new tests as well. But before I spend a lot of time on those tasks, I want to know if there's any interest in merging this.
I haven't implemented the aforementioned Fourier max pooling module yet. But if there's interest, I could add that to the PR as well.