Skip to content

Commit a3baa4d

Browse files
committed
only use the memory efficient implementation when not scripting
1 parent 6397724 commit a3baa4d

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

torchmdnet/extensions/ops.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,20 @@ def get_neighbor_pairs_kernel(
8787
num_pairs : Tensor
8888
The number of pairs found.
8989
"""
90-
if USE_MEMORY_EFFICIENT:
91-
return torch_neighbor_bruteforce_memory_efficient(
92-
strategy,
93-
positions,
94-
batch=batch,
95-
in_box_vectors=box_vectors,
96-
use_periodic=use_periodic,
97-
cutoff_lower=cutoff_lower,
98-
cutoff_upper=cutoff_upper,
99-
max_num_pairs=max_num_pairs,
100-
loop=loop,
101-
include_transpose=include_transpose,
102-
)
90+
if not torch.jit.is_scripting():
91+
if USE_MEMORY_EFFICIENT:
92+
return torch_neighbor_bruteforce_memory_efficient(
93+
strategy,
94+
positions,
95+
batch=batch,
96+
in_box_vectors=box_vectors,
97+
use_periodic=use_periodic,
98+
cutoff_lower=cutoff_lower,
99+
cutoff_upper=cutoff_upper,
100+
max_num_pairs=max_num_pairs,
101+
loop=loop,
102+
include_transpose=include_transpose,
103+
)
103104

104105
if torch.jit.is_scripting() or not positions.is_cuda:
105106

0 commit comments

Comments
 (0)