Merge pull request #8 from Arne-Thomsen/larger_tensors
added support for larger tensor sizes
jafluri authored Aug 30, 2023
2 parents 6693a79 + e582fc3 commit 3facedf
Showing 4 changed files with 93 additions and 16 deletions.
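These changes work around TensorFlow's restriction that tf.sparse.sparse_dense_matmul cannot run on GPU once output.shape[1] * nnz(a) exceeds 2^31 (the exact error message is quoted in deepsphere/utils.py below); splitting the dense operand along axis 1 keeps each partial product under that limit. A minimal usage sketch of the new max_batch_size argument follows; the import paths, layer stack, and sizes are illustrative assumptions, not taken from this commit.

import numpy as np
import healpy as hp

# Assumed import paths, for illustration only.
from deepsphere import healpy_layers as hp_nn
from deepsphere.healpy_networks import HealpyGCNN

nside = 256
indices = np.arange(hp.nside2npix(nside))  # full-sky input

layers = [
    hp_nn.HealpyChebyshev(K=5, Fout=64, activation="relu"),
    hp_nn.HealpyChebyshev(K=5, Fout=32, activation="relu"),
]

# New in this commit: max_batch_size is the largest batch the network must
# support; from it the constructor derives how many splits each
# tf.sparse.sparse_dense_matmul needs to stay below the 2^31 limit.
model = HealpyGCNN(nside=nside, indices=indices, layers=layers,
                   n_neighbors=8, max_batch_size=208)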
25 changes: 17 additions & 8 deletions deepsphere/gnn_layers.py
@@ -13,7 +13,7 @@ class Chebyshev(Model):
"""

def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=False,
use_bn=False, **kwargs):
use_bn=False, n_matmul_splits=1, **kwargs):
"""
Initializes the graph convolutional layer, assuming the input has dimension (B, M, F)
:param L: The graph Laplacian (MxM), as numpy array
@@ -23,6 +23,8 @@ def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=
:param activation: the activation function to use after the layer, defaults to linear
:param use_bias: Use learnable bias weights
:param use_bn: Apply batch norm before adding the bias
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:param kwargs: additional keyword arguments passed on to add_weight
"""

@@ -44,6 +46,7 @@ def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=
self.activation = getattr(tf.keras.activations, activation)
else:
raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
self.n_matmul_splits = n_matmul_splits
self.kwargs = kwargs

# Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L.
@@ -119,10 +122,10 @@ def call(self, input_tensor, training=False):
stack = [x0]

if self.K > 1:
x1 = tf.sparse.sparse_dense_matmul(self.sparse_L, x0)
x1 = utils.split_sparse_dense_matmul(self.sparse_L, x0, self.n_matmul_splits)
stack.append(x1)
for k in range(2, self.K):
x2 = 2 * tf.sparse.sparse_dense_matmul(self.sparse_L, x1) - x0 # M x Fin*N
x2 = 2 * utils.split_sparse_dense_matmul(self.sparse_L, x1, self.n_matmul_splits) - x0 # M x Fin*N
stack.append(x2)
x0, x1 = x1, x2
x = tf.stack(stack, axis=0)
@@ -150,7 +153,7 @@ class Monomial(Model):
A graph convolutional layer using Monomials
"""
def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=False,
use_bn=False, **kwargs):
use_bn=False, n_matmul_splits=1, **kwargs):
"""
Initializes the graph convolutional layer, assuming the input has dimension (B, M, F)
:param L: The graph Laplacian (MxM), as numpy array
@@ -160,6 +163,8 @@ def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=
:param activation: the activation function to use after the layer, defaults to linear
:param use_bias: Use learnable bias weights
:param use_bn: Apply batch norm before adding the bias
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:param kwargs: additional keyword arguments passed on to add_weight
"""

@@ -181,6 +186,7 @@ def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=
self.activation = getattr(tf.keras.activations, activation)
else:
raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
self.n_matmul_splits = n_matmul_splits
self.kwargs = kwargs

# Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L.
@@ -254,7 +260,7 @@ def call(self, input_tensor, training=False):
stack = [x0]

for k in range(1, self.K):
x1 = tf.sparse.sparse_dense_matmul(self.sparse_L, x0) # M x Fin*N
x1 = utils.split_sparse_dense_matmul(self.sparse_L, x0, self.n_matmul_splits) # M x Fin*N
stack.append(x1)
x0 = x1

@@ -378,7 +384,7 @@ class Bernstein(Model):
"""

def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=False,
use_bn=False, **kwargs):
use_bn=False, n_matmul_splits=1, **kwargs):
"""
Initializes the graph convolutional layer, assuming the input has dimension (B, M, F)
:param L: The graph Laplacian (MxM), as numpy array
@@ -388,6 +394,8 @@ def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=
:param activation: the activation function to use after the layer, defaults to linear
:param use_bias: Use learnable bias weights
:param use_bn: Apply batch norm before adding the bias
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:param kwargs: additional keyword arguments passed on to add_weight
"""

@@ -409,6 +417,7 @@ def __init__(self, L, K, Fout=None, initializer=None, activation=None, use_bias=
self.activation = getattr(tf.keras.activations, activation)
else:
raise ValueError(f"Could not find activation <{activation}> in tf.keras.activations...")
self.n_matmul_splits = n_matmul_splits
self.kwargs = kwargs

# Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L.
@@ -488,11 +497,11 @@ def call(self, input_tensor, training=False, *args, **kwargs):
x1 = x0
theta = comb(self.K,i)/(2**self.K)
for j in range(i):
x2= tf.sparse.sparse_dense_matmul(self.sparse_L, x1)
x2= utils.split_sparse_dense_matmul(self.sparse_L, x1, self.n_matmul_splits)
x1 =x2
x2=x1
for k in range(self.K-i):
x3 = 2*x2-tf.sparse.sparse_dense_matmul(self.sparse_L, x2)
x3 = 2 * x2 - utils.split_sparse_dense_matmul(self.sparse_L, x2, self.n_matmul_splits)
x2 =x3
x3 = theta*x3
stack.append(x3)
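Column-wise splitting of the dense operand is exact: a sparse-dense product acts on each column of the dense matrix independently, so L @ [X1 | X2] equals [L @ X1 | L @ X2] up to floating-point error. A quick NumPy/SciPy sanity check of that identity, purely illustrative and not part of the commit:

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
L = sparse.random(32, 32, density=0.1, format="csr", random_state=rng)  # stand-in for the Laplacian
X = rng.standard_normal((32, 8))                                        # dense operand

full = L @ X
split = np.concatenate([L @ X[:, :4], L @ X[:, 4:]], axis=1)
assert np.allclose(full, split)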
23 changes: 16 additions & 7 deletions deepsphere/healpy_layers.py
@@ -204,16 +204,18 @@ def __init__(self, K, Fout=None, initializer=None, activation=None, use_bias=Fal
self.use_bn = use_bn
self.kwargs = kwargs

def _get_layer(self, L):
def _get_layer(self, L, n_matmul_splits=1):
"""
initializes the actual layer, should be called once the graph Laplacian has been calculated
:param L: the graph laplacian
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:return: Chebyshev layer that can be called
"""

# now we init the layer
return Chebyshev(L=L, K=self.K, Fout=self.Fout, initializer=self.initializer, activation=self.activation,
use_bias=self.use_bias, use_bn=self.use_bn, **self.kwargs)
use_bias=self.use_bias, use_bn=self.use_bn, n_matmul_splits=n_matmul_splits, **self.kwargs)


class HealpyMonomial():
@@ -241,16 +243,18 @@ def __init__(self, K, Fout=None, initializer=None, activation=None, use_bias=Fal
self.use_bn = use_bn
self.kwargs = kwargs

def _get_layer(self, L):
def _get_layer(self, L, n_matmul_splits=1):
"""
initializes the actual layer, should be called once the graph Laplacian has been calculated
:param L: the graph laplacian
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:return: Monomial layer that can be called
"""

# now we init the layer
return Monomial(L=L, K=self.K, Fout=self.Fout, initializer=self.initializer, activation=self.activation,
use_bias=self.use_bias, use_bn=self.use_bn, **self.kwargs)
use_bias=self.use_bias, use_bn=self.use_bn, n_matmul_splits=n_matmul_splits, **self.kwargs)


class Healpy_ResidualLayer():
@@ -285,14 +289,17 @@ def __init__(self, layer_type, layer_kwargs, activation=None, act_before=False,
self.bn_kwargs = bn_kwargs
self.alpha = alpha

def _get_layer(self, L):
def _get_layer(self, L, n_matmul_splits=1):
"""
initializes the actual layer, should be called once the graph Laplacian has been calculated
:param L: the graph laplacian
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:return: GCNN_ResidualLayer layer that can be called
"""
# we add the graph laplacian to all kwargs
self.layer_kwargs.update({"L": L})
self.layer_kwargs.update({"n_matmul_splits": n_matmul_splits})

return GCNN_ResidualLayer(layer_type=self.layer_type, layer_kwargs=self.layer_kwargs,
activation=self.activation, act_before=self.act_before,
@@ -393,14 +400,16 @@ def __init__(self, K, Fout=None, initializer=None, activation=None, use_bias=Fal
self.use_bn = use_bn
self.kwargs = kwargs

def _get_layer(self, L):
def _get_layer(self, L, n_matmul_splits=1):
"""
initializes the actual layer, should be called once the graph Laplacian has been calculated
:param L: the graph laplacian
:param n_matmul_splits: Number of splits to apply to axis 1 of the dense tensor in the
tf.sparse.sparse_dense_matmul operations to avoid the operation's size limitation
:return: Bernstein layer that can be called
"""

# now we init the layer
return Bernstein(L=L, K=self.K, Fout=self.Fout, initializer=self.initializer, activation=self.activation,
use_bias=self.use_bias, use_bn=self.use_bn, **self.kwargs)
use_bias=self.use_bias, use_bn=self.use_bn, n_matmul_splits=n_matmul_splits, **self.kwargs)

31 changes: 30 additions & 1 deletion deepsphere/healpy_networks.py
@@ -1,5 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
import math
from tensorflow.keras.models import Sequential
import healpy as hp
from pygsp.graphs import SphereHealpix
@@ -16,14 +17,18 @@ class HealpyGCNN(Sequential):
A graph convolutional network using the Keras model API and the layers from the model
"""

def __init__(self, nside, indices, layers, n_neighbors=8):
def __init__(self, nside, indices, layers, n_neighbors=8, max_batch_size=None):
"""
Initializes a graph convolutional neural network using the healpy pixelization scheme
:param nside: integer, the nside of the input
:param indices: 1d array of indices, corresponding to the pixel ids of the input of the NN
:param layers: a list of layers that will make up the neural network
:param n_neighbors: Number of neighbors considered when building the graph, currently supported values are:
8 (default), 20, 40 and 60.
:param max_batch_size: Maximal batch size this network is supposed to handle. This determines the number of
splits in the tf.sparse.sparse_dense_matmul operations, which are subsequently applied
independently of the actual batch size. Defaults to None, in which case no such
precautions are taken, which may cause an error for large tensors.
"""
# This is necessary for every Layer
super(HealpyGCNN, self).__init__(name='')
@@ -80,6 +85,9 @@ def __init__(self, nside, indices, layers, n_neighbors=8):
current_nside = self.nside_in
current_indices = indices

# in general, the feature dimension of the input is unknown
current_Fin = None

for layer in self.layers_in:
if isinstance(layer, (hp_nn.HealpyChebyshev, hp_nn.HealpyMonomial, hp_nn.Healpy_ResidualLayer,
hp_nn.Healpy_Transformer,hp_nn.HealpyBernstein)):
@@ -90,6 +98,21 @@ def __init__(self, nside, indices, layers, n_neighbors=8):
current_A = sphere.A
if isinstance(layer, hp_nn.Healpy_Transformer):
actual_layer = layer._get_layer(current_A)
elif isinstance(layer, (hp_nn.HealpyChebyshev, hp_nn.HealpyMonomial, hp_nn.HealpyBernstein,
hp_nn.Healpy_ResidualLayer)):
if (max_batch_size is not None) and (current_Fin is not None):
n_matmul_splits = 1
while not (
# tf.split only does even splits for integer arguments
(max_batch_size * current_Fin % n_matmul_splits == 0) and
# due to tf.sparse.sparse_dense_matmul
(n_matmul_splits >= max_batch_size * current_Fin * len(current_L.indices) / 2**31)
):
n_matmul_splits += 1
actual_layer = layer._get_layer(current_L, n_matmul_splits)

else:
actual_layer = layer._get_layer(current_L)
else:
actual_layer = layer._get_layer(current_L)
self.layers_use.append(actual_layer)
@@ -110,6 +133,12 @@ def __init__(self, nside, indices, layers, n_neighbors=8):
else:
self.layers_use.append(layer)

try:
current_Fin = layer.Fout
except AttributeError:
# don't update, this is for example the case for residual or pooling layers that have Fin = Fout
pass

# Now that we have everything we can super init...
super(HealpyGCNN, self).__init__(layers=self.layers_use)

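The split count chosen in the loop added to HealpyGCNN.__init__ above must satisfy two constraints: tf.split with an integer argument only performs even splits, so the number of splits has to divide max_batch_size * Fin, and each split has to keep output.shape[1] * nnz(L) below 2^31 for tf.sparse.sparse_dense_matmul. A standalone sketch of the same search, written as a hypothetical helper that is not part of the commit:

def n_matmul_splits_for(max_batch_size, fin, nnz):
    """Smallest split count that divides the dense axis evenly and keeps each
    split of tf.sparse.sparse_dense_matmul below the 2^31 size limit.
    Hypothetical helper mirroring the loop in HealpyGCNN.__init__ above."""
    cols = max_batch_size * fin  # axis 1 of the (M, Fin*B) dense operand
    n_splits = 1
    while not (cols % n_splits == 0 and n_splits >= cols * nnz / 2**31):
        n_splits += 1
    return n_splits

# Example: 208 maps with 128 features and a Laplacian with ~1.3e6 non-zeros
# give cols * nnz ≈ 3.5e10, about 16.1 times the 2^31 limit, so at least 17
# splits are needed; the smallest divisor of 26624 (= 208 * 128) that large is 26.
print(n_matmul_splits_for(208, 128, 1_300_000))  # -> 26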
30 changes: 30 additions & 0 deletions deepsphere/utils.py
@@ -3,6 +3,7 @@
import numpy as np
from scipy import sparse
import healpy as hp
import tensorflow as tf


def extend_indices(indices, nside_in, nside_out, nest=True):
@@ -42,3 +43,32 @@ def rescale_L(L, lmax=2, scale=1):
L *= 2 * scale / lmax
L -= I
return L

@tf.function
def split_sparse_dense_matmul(sparse_tensor, dense_tensor, n_splits=1):
"""
Splits axis 1 of the dense_tensor such that tensorflow can handle the size of the computation.
:param sparse_tensor: Input sparse tensor of rank 2.
:param dense_tensor: Input dense tensor of rank 2.
:param n_splits: Integer number of splits applied to axis 1 of dense_tensor.
For reference, the error message to be avoided is:
'Cannot use GPU when output.shape[1] * nnz(a) > 2^31 [Op:SparseTensorDenseMatMul]
Call arguments received by layer "chebyshev" (type Chebyshev):
• input_tensor=tf.Tensor(shape=(208, 7264, 128), dtype=float32)
• training=False'
"""
if n_splits > 1:
print(f"Tracing... Due to tensor size, tf.sparse.sparse_dense_matmul is executed over {n_splits} splits."
f" Beware of the resulting performance penalty.")
dense_splits = tf.split(dense_tensor, n_splits, axis=1)
result = []
for dense_split in dense_splits:
result.append(tf.sparse.sparse_dense_matmul(sparse_tensor, dense_split))
result = tf.concat(result, axis=1)
else:
result = tf.sparse.sparse_dense_matmul(sparse_tensor, dense_tensor)

return result
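A small equivalence check of the new helper, assuming it is importable as deepsphere.utils.split_sparse_dense_matmul; the tensors and split count are purely illustrative:

import tensorflow as tf
from deepsphere import utils  # assumed import path

# Tiny sparse (4x4) and dense (4x6) tensors; axis 1 of the dense tensor
# must be evenly divisible by n_splits, here 6 / 3 = 2 columns per split.
sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2], [3, 1]],
                            values=[1.0, 2.0, 3.0],
                            dense_shape=[4, 4])
dense = tf.random.normal([4, 6])

split_result = utils.split_sparse_dense_matmul(sp, dense, n_splits=3)
full_result = tf.sparse.sparse_dense_matmul(sp, dense)
tf.debugging.assert_near(split_result, full_result)  # same result up to float error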
