Skip to content

Commit

Permalink
Merge pull request #100 from gsivori/sample_positions_with_holes
Browse files Browse the repository at this point in the history
Position sampling in Environments with holes.
  • Loading branch information
TomGeorge1234 authored Jan 29, 2024
2 parents f317916 + f224975 commit cd429d8
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,6 @@ def sample_positions(self, n=10, method="uniform_jitter"):
Args:
n (int): number of features
method: "uniform", "uniform_jittered" or "random" for how points are distributed
true_random: if True, just randomly scatters point
Returns:
array: (n x dimensionality) of positions
"""
Expand All @@ -581,33 +580,44 @@ def sample_positions(self, n=10, method="uniform_jitter"):
positions[:, 1] = np.random.uniform(
self.extent[2], self.extent[3], size=n
)
if (self.is_rectangular is False) or (self.has_holes is True):
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute force this by randomly resampling these points until all fall within the env.
for i, pos in enumerate(positions):
if self.check_if_position_is_in_environment(pos) == False:
pos = self.sample_positions(n=1, method="random").reshape(
-1
) # this recursive call must pass eventually, assuming the env is sufficiently large. this is why we don't need a while loop
positions[i] = pos
elif method[:7] == "uniform":
ex = self.extent
area = (ex[1] - ex[0]) * (ex[3] - ex[2])
if (self.has_holes is True):
area -= sum(shapely.geometry.Polygon(hole).area for hole in self.holes)
delta = np.sqrt(area / n)
x = np.linspace(ex[0] + delta /2, ex[1] - delta /2, int((ex[1] - ex[0])/delta))
y = np.linspace(ex[2] + delta /2, ex[3] - delta /2, int((ex[3] - ex[2])/delta))
positions = np.array(np.meshgrid(x, y)).reshape(2, -1).T

if (self.is_rectangular is False) or (self.has_holes is True):
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole).
delpos = [i for (i,pos) in enumerate(positions) if self.check_if_position_is_in_environment(pos) == False]
positions = np.delete(positions,delpos,axis=0) # this will delete illegal positions

n_uniformly_distributed = positions.shape[0]
if method[7:] == "_jitter":
positions += np.random.uniform(
-0.45 * delta, 0.45 * delta, positions.shape
)
)
n_remaining = n - n_uniformly_distributed
if n_remaining > 0:
positions_remaining = self.sample_positions(
n=n_remaining, method="random"
# sample remaining from available positions with further jittering (delta = delta/2)
positions_remaining = np.array([positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=False)])
delta /= 2
positions_remaining += np.random.uniform(
-0.45 * delta, 0.45 * delta, positions_remaining.shape
)
positions = np.vstack((positions, positions_remaining))

if (self.is_rectangular is False) or (self.has_holes is True):
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute force this by randomly resampling these points until all fall within the env.
for i, pos in enumerate(positions):
if self.check_if_position_is_in_environment(pos) == False:
pos = self.sample_positions(n=1, method="random").reshape(
-1
) # this recursive call must pass eventually, assuming the env is sufficiently large. this is why we don't need a while loop
positions[i] = pos
return positions

def discretise_environment(self, dx=None):
Expand Down

0 comments on commit cd429d8

Please sign in to comment.