Skip to content

Commit e4f8659

Browse files
authored
Merge pull request #1199 from ErwanH29/nearest_neighbour_fix
Resolving issue #978: To ensure particles do not consider themselves when calculating nearest neighbours
2 parents c6cb09d + 584b8bf commit e4f8659

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

examples/textbook/relax_gas_and_stars.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
import numpy
22
import pickle
3-
from amuse.lab import *
3+
4+
from amuse.lab import (
5+
units, nbody_system,
6+
new_plummer_model, write_set_to_file,
7+
new_salpeter_mass_distribution,
8+
)
49
from amuse.community.fastkick.interface import FastKick
5-
from amuse.ext.relax_sph import relax
10+
from amuse.community.fi.interface import Fi
11+
from amuse.ext.relax_sph import relax, monitor_energy
612
from amuse.ext.spherical_model import new_gas_plummer_distribution
713
from amuse.community.fractalcluster.interface import new_fractal_cluster_model
814

15+
16+
from prepare_figure import *
17+
918
###BOOKLISTSTART1###
1019
def check_energy_conservation(system, i_step, time, n_steps):
1120
unit = units.J
@@ -163,7 +172,6 @@ def make_map(sph,N=100,L=1):
163172
def plot_hydro_and_stars(hydro, stars):
164173
x_label = "x [pc]"
165174
y_label = "y [pc]"
166-
from prepare_figure import *
167175
fig = single_frame(x_label, y_label, logx=False, logy=False,
168176
xsize=12, ysize=12)
169177

src/amuse/datamodel/particle_attributes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -923,14 +923,15 @@ def distances_squared(particles, other_particles):
923923
return (dxdydz**2).sum(-1)
924924

925925

926-
def nearest_neighbour(particles, neighbours=None, max_array_length=10000000):
926+
def nearest_neighbour(particles, neighbours=None, self_search=False, max_array_length=10000000):
927927
"""
928928
Returns the nearest neighbour of each particle in this set. If the 'neighbours'
929929
particle set is supplied, the search is performed on the neighbours set, for
930930
each particle in the orignal set. Otherwise the nearest neighbour in the same
931931
set is searched.
932932
933933
:argument neighbours: the particle set in which to search for the nearest neighbour (optional)
934+
:argument self_search: if True, the nearest neighbour can be the particle itself (default False)
934935
935936
>>> from amuse.datamodel import Particles
936937
>>> particles = Particles(3)
@@ -964,7 +965,7 @@ def nearest_neighbour(particles, neighbours=None, max_array_length=10000000):
964965
)
965966
for indices in indices_in_each_batch:
966967
distances_squared = particles[indices].distances_squared(other_particles)
967-
if neighbours is None:
968+
if not self_search and neighbours is None:
968969
diagonal_indices = (numpy.arange(len(indices)), indices)
969970
distances_squared.number[
970971
diagonal_indices
@@ -973,11 +974,11 @@ def nearest_neighbour(particles, neighbours=None, max_array_length=10000000):
973974
return other_particles[numpy.concatenate(neighbour_indices)]
974975

975976
distances_squared = particles.distances_squared(other_particles)
976-
if neighbours is None:
977+
if not self_search and neighbours is None: # can't be your own neighbour
977978
diagonal_indices = numpy.diag_indices(len(particles))
978979
distances_squared.number[
979980
diagonal_indices
980-
] = numpy.inf # can't be your own neighbour
981+
] = numpy.inf
981982
return other_particles[distances_squared.argmin(axis=1)]
982983

983984

src/tests/core_tests/test_particle_attributes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def test11(self):
173173
particles.z = 0.0 | units.m
174174
self.assertEqual(particles.nearest_neighbour()[0], particles[1])
175175
self.assertEqual(particles.nearest_neighbour()[1:].key, particles[:-1].key)
176+
self.assertEqual(particles.nearest_neighbour(self_search=True).key, particles.key)
176177

177178
neighbours = Particles(3)
178179
neighbours.x = [1.0, 10.0, 100.0] | units.m

0 commit comments

Comments
 (0)