Chebyshev radial spectrum #18

Open: wants to merge 1 commit into base branch `main`.
spex/radial/simple/__init__.py (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
 from .bernstein import Bernstein
 from .simple import Simple
+from .chebyshev import Chebyshev

-__all__ = [Bernstein, Simple]
+__all__ = [Chebyshev, Bernstein, Simple]
spex/radial/simple/chebyshev.py (64 additions, 0 deletions, new file)
@@ -0,0 +1,64 @@
import numpy as np
import torch

from .simple import Simple


class Chebyshev(Simple):
    """Chebyshev polynomial basis.

    The basis can optionally be transformed with a learned linear layer
    (``trainable=True``), optionally with a separate transformation per degree
    (``per_degree=True``). The target number of features is specified by
    ``num_features``.

    Attributes:
        cutoff (Tensor): Cutoff distance.
        max_angular (int): Maximum spherical harmonic order.
        n_per_l (list): Number of features per degree.

    """

    def __init__(self, *args, **kwargs):
        """Initialise the Chebyshev basis.

        Args:
            cutoff (float): Cutoff distance.
            num_radial (int): Number of radial basis functions.
            max_angular (int): Maximum spherical harmonic order.
            trainable (bool, optional): Whether a learned linear transformation
                is applied.
            per_degree (bool, optional): Whether to have a separate learned
                transform per degree.
            num_features (int, optional): Target number of features for learned
                transformation. Defaults to ``num_radial``.

        """
        super().__init__(*args, **kwargs)

        n = self.num_radial
        v = np.arange(n)
        self.register_buffer("v", torch.from_numpy(v))

    def expand(self, r):
        """Compute the Chebyshev polynomial basis.

        Args:
            r (Tensor): Input distances of shape ``[pair]``.

        Returns:
            Expansion of shape ``[pair, num_radial]``.
        """
        # rescale distances so that [0, cutoff] maps to [0, 1]
        r = r.unsqueeze(-1) / self.cutoff
        k = self.v

        # distances outside [0, cutoff] are masked to zero
        mask0 = r < 0
        mask1 = r > 1
        mask = torch.logical_or(mask0, mask1)

        # clamp the argument of arccos to its domain so that out-of-range
        # distances do not produce NaNs (torch.where would otherwise still
        # propagate NaN gradients through the discarded branch)
        x = torch.clamp(2 * r - 1, min=-1.0, max=1.0)
        basis = torch.cos(k * torch.arccos(x))
        y = torch.where(mask, torch.zeros_like(basis), basis)

        return y
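For context: ``expand`` evaluates Chebyshev polynomials of the first kind, T_k(x) = cos(k arccos x), on distances rescaled to x = 2 r / cutoff - 1, and returns zeros for distances outside [0, cutoff]. A minimal usage sketch follows; the parameter values are illustrative and not taken from this PR.

```python
import torch

from spex.radial.simple import Chebyshev

# Illustrative parameters; the constructor signature follows the docstring above.
radial = Chebyshev(cutoff=5.0, num_radial=8, max_angular=3)

# Pair distances of shape [pair]; anything beyond the cutoff expands to zeros.
r = torch.linspace(0.0, 5.0, 10)

# Raw Chebyshev expansion of shape [pair, num_radial].
y = radial.expand(r)
print(y.shape)  # torch.Size([10, 8])
```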
tests/test_chebyshev.py (54 additions, 0 deletions, new file)
@@ -0,0 +1,54 @@
import numpy as np
import torch

from unittest import TestCase


class TestChebyshev(TestCase):
    """Basic test suite for the Chebyshev class."""

    def setUp(self):
        self.num_radial = 128
        self.cutoff = 5.0
        self.max_angular = 3
        self.num_features = None
        self.trainable = False
        self.per_degree = False

        self.r = np.random.random(25)

    def test_jit(self):
        """Test if Chebyshev class works with TorchScript."""
        from spex.radial.simple import Chebyshev

        radial = Chebyshev(
            cutoff=self.cutoff,
            num_radial=self.num_radial,
            max_angular=self.max_angular,
            num_features=self.num_features,
            trainable=self.trainable,
            per_degree=self.per_degree,
        )
        radial = torch.jit.script(radial)
        radial(torch.tensor(self.r, dtype=torch.float32))

    def test_hardcoded(self):
        """Test the Chebyshev class with hardcoded parameters."""
        from spex.radial.simple import Chebyshev

        radial = Chebyshev(
            cutoff=self.cutoff,
            num_radial=self.num_radial,
            max_angular=self.max_angular,
            num_features=self.num_features,
        )

        assert radial.cutoff == self.cutoff
        assert radial.num_radial == self.num_radial
        assert radial.max_angular == self.max_angular

        # Check the shape of the expansion
        torch_r = torch.tensor(self.r, dtype=torch.float32)
        torch_output = radial.expand(torch_r).detach().numpy()

        assert torch_output.shape == (self.r.shape[0], self.num_radial)
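A possible follow-up check, sketched here under the assumption that ``expand`` should match the closed form T_k(2 r / cutoff - 1) and not part of this PR: compare the output against NumPy's Chebyshev evaluation via ``np.polynomial.chebyshev.chebvander`` instead of only asserting the shape.

```python
import numpy as np
import torch

from spex.radial.simple import Chebyshev

cutoff, num_radial = 5.0, 16  # illustrative values
radial = Chebyshev(cutoff=cutoff, num_radial=num_radial, max_angular=3)

# Distances strictly inside the cutoff, in float64 to keep both paths comparable.
r = np.random.uniform(0.0, cutoff, size=25)

# Columns of chebvander are T_0 .. T_{num_radial - 1} evaluated at x = 2 r / cutoff - 1.
expected = np.polynomial.chebyshev.chebvander(2 * r / cutoff - 1, num_radial - 1)

actual = radial.expand(torch.tensor(r, dtype=torch.float64)).detach().numpy()
np.testing.assert_allclose(actual, expected, atol=1e-8)
```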