Skip to content

Commit 8b2d427

Browse files
authored
getNeighborPairs() supports periodic boundary conditions (#70)
* getNeighborPairs() supports periodic boundary conditions * CUDA implementation of periodic boundary conditions * Fixed error in autograd * Skip test that causes CUDA assertion * Added checks for invalid box vectors
1 parent 7491583 commit 8b2d427

File tree

5 files changed

+154
-19
lines changed

5 files changed

+154
-19
lines changed

src/pytorch/neighbors/TestNeighbors.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,14 @@ def test_neighbor_grads(dtype, num_atoms, grad):
128128
raise ValueError('grad')
129129

130130
if dtype == pt.float32:
131-
assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-5, rtol=1e-3)
131+
assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-3, rtol=1e-3)
132132
else:
133133
assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-8, rtol=1e-5)
134134

135-
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
135+
# The following test is only run on the CPU. Running it on the GPU triggers a
136+
# CUDA assertion, which causes all tests run after it to fail.
137+
138+
@pytest.mark.parametrize('device', ['cpu'])
136139
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
137140
def test_too_many_neighbors(device, dtype):
138141

@@ -143,4 +146,67 @@ def test_too_many_neighbors(device, dtype):
143146
with pytest.raises(RuntimeError):
144147
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
145148
getNeighborPairs(positions, cutoff=1, max_num_neighbors=1)
146-
pt.cuda.synchronize()
149+
pt.cuda.synchronize()
150+
151+
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
152+
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
153+
def test_periodic_neighbors(device, dtype):
154+
155+
if not pt.cuda.is_available() and device == 'cuda':
156+
pytest.skip('No GPU')
157+
158+
# Generate random positions
159+
num_atoms = 100
160+
positions = (20 * pt.randn((num_atoms, 3), device=device, dtype=dtype)) - 10
161+
box_vectors = pt.tensor([[10, 0, 0], [2, 12, 0], [0, 1, 11]], device=device, dtype=dtype)
162+
cutoff = 5.0
163+
164+
# Get neighbor pairs
165+
ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1))
166+
ref_positions = positions.cpu().numpy()
167+
ref_vectors = box_vectors.cpu().numpy()
168+
ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]]
169+
ref_deltas -= np.outer(np.round(ref_deltas[:,2]/ref_vectors[2,2]), ref_vectors[2])
170+
ref_deltas -= np.outer(np.round(ref_deltas[:,1]/ref_vectors[1,1]), ref_vectors[1])
171+
ref_deltas -= np.outer(np.round(ref_deltas[:,0]/ref_vectors[0,0]), ref_vectors[0])
172+
ref_distances = np.linalg.norm(ref_deltas, axis=1)
173+
174+
# Filter the neighbor pairs
175+
mask = ref_distances > cutoff
176+
ref_neighbors[:, mask] = -1
177+
ref_deltas[mask, :] = np.nan
178+
ref_distances[mask] = np.nan
179+
180+
# Find the number of neighbors
181+
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))
182+
max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1)
183+
184+
# Compute results
185+
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors)
186+
187+
# Check device
188+
assert neighbors.device == positions.device
189+
assert deltas.device == positions.device
190+
assert distances.device == positions.device
191+
192+
# Check types
193+
assert neighbors.dtype == pt.int32
194+
assert deltas.dtype == dtype
195+
assert distances.dtype == dtype
196+
197+
# Covert the results
198+
neighbors = neighbors.cpu().numpy()
199+
deltas = deltas.cpu().numpy()
200+
distances = distances.cpu().numpy()
201+
202+
# Sort the neighbors
203+
# NOTE: GPU returns the neighbor in a non-deterministic order
204+
ref_neighbors, ref_deltas, ref_distances = sort_neighbors(ref_neighbors, ref_deltas, ref_distances)
205+
neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances)
206+
207+
# Resize the reference
208+
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors)
209+
210+
assert np.all(ref_neighbors == neighbors)
211+
assert np.allclose(ref_deltas, deltas, equal_nan=True)
212+
assert np.allclose(ref_distances, distances, equal_nan=True)

src/pytorch/neighbors/getNeighborPairs.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from torch import ops, Tensor
2-
from typing import Tuple
1+
from torch import empty, ops, Tensor
2+
from typing import Optional, Tuple
33

44

5-
def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1) -> Tuple[Tensor, Tensor]:
5+
def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
66
'''
77
Returns indices and distances of atom pairs within a given cutoff distance.
88
@@ -16,6 +16,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
1616
molecule, where most of the atoms are beyond the cutoff distance of each
1717
other.
1818
19+
This function optionally supports periodic boundary conditions with
20+
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
21+
certain requirements:
22+
23+
`a[1] = a[2] = b[2] = 0`
24+
`a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`
25+
`a[0] >= 2*b[0]`
26+
`a[0] >= 2*c[0]`
27+
`b[1] >= 2*c[1]`
28+
29+
These requirements correspond to a particular rotation of the system and
30+
reduced form of the vectors, as well as the requirement that the cutoff be
31+
no larger than half the box width.
32+
1933
Parameters
2034
----------
2135
positions: `torch.Tensor`
@@ -26,6 +40,10 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
2640
max_num_neighbors: int, optional
2741
Maximum number of neighbors per atom. If set to `-1` (default),
2842
all possible combinations of atom pairs are included.
43+
box_vectors: `torch.Tensor`, optional
44+
The vectors defining the periodic box. This must have shape `(3, 3)`,
45+
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
46+
If this is omitted, periodic boundary conditions are not applied.
2947
3048
Returns
3149
-------
@@ -103,4 +121,6 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
103121
tensor([1., 1., nan, nan, nan, nan]))
104122
'''
105123

106-
return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors)
124+
if box_vectors is None:
125+
box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype)
126+
return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors)

src/pytorch/neighbors/getNeighborPairsCPU.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@ using torch::Scalar;
1313
using torch::hstack;
1414
using torch::vstack;
1515
using torch::Tensor;
16+
using torch::outer;
17+
using torch::round;
1618

1719
static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
1820
const Scalar& cutoff,
19-
const Scalar& max_num_neighbors) {
21+
const Scalar& max_num_neighbors,
22+
const Tensor& box_vectors) {
2023

2124
TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
2225
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
@@ -25,6 +28,25 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
2528

2629
TORCH_CHECK(cutoff.to<double>() > 0, "Expected \"cutoff\" to be positive");
2730

31+
if (box_vectors.size(0) != 0) {
32+
TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions");
33+
TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)");
34+
double v[3][3];
35+
for (int i = 0; i < 3; i++)
36+
for (int j = 0; j < 3; j++)
37+
v[i][j] = box_vectors[i][j].item<double>();
38+
double c = cutoff.to<double>();
39+
TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0");
40+
TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0");
41+
TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0");
42+
TORCH_CHECK(v[0][0] >= 2*c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff");
43+
TORCH_CHECK(v[1][1] >= 2*c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff");
44+
TORCH_CHECK(v[2][2] >= 2*c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff");
45+
TORCH_CHECK(v[0][0] >= 2*v[1][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]");
46+
TORCH_CHECK(v[0][0] >= 2*v[2][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]");
47+
TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]");
48+
}
49+
2850
const int max_num_neighbors_ = max_num_neighbors.to<int>();
2951
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
3052
"Expected \"max_num_neighbors\" to be positive or equal to -1");
@@ -39,12 +61,17 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
3961

4062
Tensor neighbors = vstack({rows, columns});
4163
Tensor deltas = index_select(positions, 0, rows) - index_select(positions, 0, columns);
64+
if (box_vectors.size(0) != 0) {
65+
deltas -= outer(round(deltas.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2}));
66+
deltas -= outer(round(deltas.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1}));
67+
deltas -= outer(round(deltas.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0}));
68+
}
4269
Tensor distances = frobenius_norm(deltas, 1);
4370

4471
if (max_num_neighbors_ == -1) {
4572
const Tensor mask = distances > cutoff;
4673
neighbors.index_put_({Slice(), mask}, -1);
47-
deltas = deltas.clone(); // Brake an autograd loop
74+
deltas = deltas.clone(); // Break an autograd loop
4875
deltas.index_put_({mask, Slice()}, NAN);
4976
distances.index_put_({mask}, NAN);
5077

src/pytorch/neighbors/getNeighborPairsCUDA.cu

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ template <typename scalar_t> __global__ void forward_kernel(
3131
const Accessor<scalar_t, 2> positions,
3232
const scalar_t cutoff2,
3333
const bool store_all_pairs,
34+
const bool use_periodic,
3435
Accessor<int32_t, 1> i_curr_pair,
3536
Accessor<int32_t, 2> neighbors,
3637
Accessor<scalar_t, 2> deltas,
37-
Accessor<scalar_t, 1> distances
38+
Accessor<scalar_t, 1> distances,
39+
Accessor<scalar_t, 2> box_vectors
3840
) {
3941
const int32_t index = blockIdx.x * blockDim.x + threadIdx.x;
4042
if (index >= num_all_pairs) return;
@@ -43,9 +45,20 @@ template <typename scalar_t> __global__ void forward_kernel(
4345
if (row * (row - 1) > 2 * index) row--;
4446
const int32_t column = index - row * (row - 1) / 2;
4547

46-
const scalar_t delta_x = positions[row][0] - positions[column][0];
47-
const scalar_t delta_y = positions[row][1] - positions[column][1];
48-
const scalar_t delta_z = positions[row][2] - positions[column][2];
48+
scalar_t delta_x = positions[row][0] - positions[column][0];
49+
scalar_t delta_y = positions[row][1] - positions[column][1];
50+
scalar_t delta_z = positions[row][2] - positions[column][2];
51+
if (use_periodic) {
52+
scalar_t scale3 = round(delta_z/box_vectors[2][2]);
53+
delta_x -= scale3*box_vectors[2][0];
54+
delta_y -= scale3*box_vectors[2][1];
55+
delta_z -= scale3*box_vectors[2][2];
56+
scalar_t scale2 = round(delta_y/box_vectors[1][1]);
57+
delta_x -= scale2*box_vectors[1][0];
58+
delta_y -= scale2*box_vectors[1][1];
59+
scalar_t scale1 = round(delta_x/box_vectors[0][0]);
60+
delta_x -= scale1*box_vectors[0][0];
61+
}
4962
const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
5063

5164
if (distance2 > cutoff2) return;
@@ -89,7 +102,8 @@ public:
89102
static tensor_list forward(AutogradContext* ctx,
90103
const Tensor& positions,
91104
const Scalar& cutoff,
92-
const Scalar& max_num_neighbors) {
105+
const Scalar& max_num_neighbors,
106+
const Tensor& box_vectors) {
93107

94108
TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
95109
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
@@ -100,6 +114,12 @@ public:
100114
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
101115
"Expected \"max_num_neighbors\" to be positive or equal to -1");
102116

117+
const bool use_periodic = (box_vectors.size(0) != 0);
118+
if (use_periodic) {
119+
TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions");
120+
TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)");
121+
}
122+
103123
// Decide the algorithm
104124
const bool store_all_pairs = max_num_neighbors_ == -1;
105125
const int num_atoms = positions.size(0);
@@ -125,10 +145,12 @@ public:
125145
get_accessor<scalar_t, 2>(positions),
126146
cutoff_ * cutoff_,
127147
store_all_pairs,
148+
use_periodic,
128149
get_accessor<int32_t, 1>(i_curr_pair),
129150
get_accessor<int32_t, 2>(neighbors),
130151
get_accessor<scalar_t, 2>(deltas),
131-
get_accessor<scalar_t, 1>(distances));
152+
get_accessor<scalar_t, 1>(distances),
153+
get_accessor<scalar_t, 2>(box_vectors));
132154
});
133155

134156
ctx->save_for_backward({neighbors, deltas, distances});
@@ -165,14 +187,14 @@ public:
165187
get_accessor<scalar_t, 2>(grad_positions));
166188
});
167189

168-
return {grad_positions, Tensor(), Tensor()};
190+
return {grad_positions, Tensor(), Tensor(), Tensor()};
169191
}
170192
};
171193

172194
TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) {
173195
m.impl("getNeighborPairs",
174-
[](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors){
175-
const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors);
196+
[](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors, const Tensor& box_vectors){
197+
const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors, box_vectors);
176198
return make_tuple(results[0], results[1], results[2]);
177199
});
178200
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <torch/extension.h>
22

33
TORCH_LIBRARY(neighbors, m) {
4-
m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors) -> (Tensor neighbors, Tensor deltas, Tensor distances)");
4+
m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors) -> (Tensor neighbors, Tensor deltas, Tensor distances)");
55
}

0 commit comments

Comments
 (0)